diff --git a/pystencils_walberla/codegen.py b/pystencils_walberla/codegen.py
index 7cf34bedc26b269900ecaeb14ddf5618df9644bf..ed989c892d0ec70f95e7aa5dc5ef5113904b712e 100644
--- a/pystencils_walberla/codegen.py
+++ b/pystencils_walberla/codegen.py
@@ -6,6 +6,7 @@ from jinja2 import Environment, PackageLoader
 
 from pystencils import (
     Assignment, AssignmentCollection, Field, FieldType, create_kernel, create_staggered_kernel)
+from pystencils.astnodes import KernelFunction
 from pystencils.backends.cbackend import get_headers
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
 from pystencils.stencil import inverse_direction, offset_to_direction_string
@@ -29,7 +30,7 @@ def generate_sweep(generation_context, class_name, assignments,
                             defines where to write generated files, if OpenMP is available or which SIMD instruction
                             set should be used. See waLBerla examples on how to get a context.
         class_name: name of the generated sweep class
-        assignments: list of assignments defining the stencil update rule
+        assignments: list of assignments defining the stencil update rule or a :class:`KernelFunction`
         namespace: the generated class is accessible as walberla::<namespace>::<class_name>
         field_swaps: sequence of field pairs (field, temporary_field). The generated sweep only gets the first field
                      as argument, creating a temporary field internally which is swapped with the first field after
@@ -48,7 +49,10 @@ def generate_sweep(generation_context, class_name, assignments,
     if not generation_context.cuda and create_kernel_params['target'] == 'gpu':
         return
 
-    if not staggered:
+    if isinstance(assignments, KernelFunction):
+        ast = assignments
+        create_kernel_params['target'] = ast.target
+    elif not staggered:
         ast = create_kernel(assignments, **create_kernel_params)
     else:
         ast = create_staggered_kernel(assignments, **create_kernel_params)