diff --git a/pystencils_walberla/codegen.py b/pystencils_walberla/codegen.py
index 7cf34bedc26b269900ecaeb14ddf5618df9644bf..ed989c892d0ec70f95e7aa5dc5ef5113904b712e 100644
--- a/pystencils_walberla/codegen.py
+++ b/pystencils_walberla/codegen.py
@@ -6,6 +6,7 @@ from jinja2 import Environment, PackageLoader
 
 from pystencils import (
     Assignment, AssignmentCollection, Field, FieldType, create_kernel, create_staggered_kernel)
+from pystencils.astnodes import KernelFunction
 from pystencils.backends.cbackend import get_headers
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
 from pystencils.stencil import inverse_direction, offset_to_direction_string
@@ -29,7 +30,7 @@ def generate_sweep(generation_context, class_name, assignments,
                             defines where to write generated files, if OpenMP is available or which SIMD
                             instruction set should be used. See waLBerla examples on how to get a context.
         class_name: name of the generated sweep class
-        assignments: list of assignments defining the stencil update rule
+        assignments: list of assignments defining the stencil update rule or a :class:`KernelFunction`
         namespace: the generated class is accessible as walberla::<namespace>::<class_name>
         field_swaps: sequence of field pairs (field, temporary_field). The generated sweep only gets the first field
                      as argument, creating a temporary field internally which is swapped with the first field after
@@ -48,7 +49,10 @@ def generate_sweep(generation_context, class_name, assignments,
     if not generation_context.cuda and create_kernel_params['target'] == 'gpu':
         return
 
-    if not staggered:
+    if isinstance(assignments, KernelFunction):
+        ast = assignments
+        create_kernel_params['target'] = ast.target
+    elif not staggered:
         ast = create_kernel(assignments, **create_kernel_params)
     else:
         ast = create_staggered_kernel(assignments, **create_kernel_params)
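
Usage sketch (not part of the patch): the new branch lets a pre-built KernelFunction be passed to
generate_sweep directly, in which case its target is reused instead of creating the kernel again from
an assignment list. The snippet below is a minimal illustration only; the field names, the Jacobi-style
update rule and the CodeGeneration context manager are assumptions, not taken from this diff.

    import pystencils as ps
    from pystencils_walberla import CodeGeneration, generate_sweep

    with CodeGeneration() as ctx:
        # Hypothetical fields and update rule, purely for illustration.
        src, dst = ps.fields("src, dst: double[3D]", layout='fzyx')
        update_rule = [ps.Assignment(dst.center,
                                     (src[1, 0, 0] + src[-1, 0, 0] +
                                      src[0, 1, 0] + src[0, -1, 0] +
                                      src[0, 0, 1] + src[0, 0, -1]) / 6)]

        # Build the KernelFunction AST ourselves and hand it to generate_sweep;
        # with this change the sweep takes the AST as-is and picks up ast.target.
        ast = ps.create_kernel(update_rule, target='cpu')
        generate_sweep(ctx, 'JacobiSweep', ast, field_swaps=[(src, dst)])

Passing an assignment list still works as before; the isinstance check only short-circuits kernel creation
when the caller has already produced an AST (e.g. with custom create_kernel options).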