From d482fd91048e998cfbc9426d6e0dc3645052ba2a Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 28 Jan 2020 15:42:59 +0100
Subject: [PATCH] Allow construction of sweeps from ASTs

---
 pystencils_walberla/codegen.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/pystencils_walberla/codegen.py b/pystencils_walberla/codegen.py
index 7cf34be..ed989c8 100644
--- a/pystencils_walberla/codegen.py
+++ b/pystencils_walberla/codegen.py
@@ -6,6 +6,7 @@ from jinja2 import Environment, PackageLoader
 
 from pystencils import (
     Assignment, AssignmentCollection, Field, FieldType, create_kernel, create_staggered_kernel)
+from pystencils.astnodes import KernelFunction
 from pystencils.backends.cbackend import get_headers
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
 from pystencils.stencil import inverse_direction, offset_to_direction_string
@@ -29,7 +30,7 @@ def generate_sweep(generation_context, class_name, assignments,
                             defines where to write generated files, if OpenMP is available or which SIMD instruction
                             set should be used. See waLBerla examples on how to get a context.
         class_name: name of the generated sweep class
-        assignments: list of assignments defining the stencil update rule
+        assignments: list of assignments defining the stencil update rule or a :class:`KernelFunction`
         namespace: the generated class is accessible as walberla::<namespace>::<class_name>
         field_swaps: sequence of field pairs (field, temporary_field). The generated sweep only gets the first field
                      as argument, creating a temporary field internally which is swapped with the first field after
@@ -48,7 +49,10 @@ def generate_sweep(generation_context, class_name, assignments,
     if not generation_context.cuda and create_kernel_params['target'] == 'gpu':
         return
 
-    if not staggered:
+    if isinstance(assignments, KernelFunction):
+        ast = assignments
+        create_kernel_params['target'] = ast.target
+    elif not staggered:
         ast = create_kernel(assignments, **create_kernel_params)
     else:
         ast = create_staggered_kernel(assignments, **create_kernel_params)
-- 
GitLab