From a5e5f2c3ffb9aa5d228b155942cf5ad836363c77 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Wed, 31 Oct 2018 18:05:18 +0100
Subject: [PATCH] walberla boundary class now also works for GPU boundaries

---
 jinja_filters.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/jinja_filters.py b/jinja_filters.py
index bb72cd2..8bc31cd 100644
--- a/jinja_filters.py
+++ b/jinja_filters.py
@@ -111,7 +111,7 @@ def field_extraction_code(field_accesses, field_name, is_temporary, declaration_
             return "%s * %s;" % (field_type, field_name)
         else:
             prefix = "" if no_declaration else "auto "
-            return "%s%s = block->uncheckedFastGetData< %s >(%sID);" % (prefix, field_name, field_type, field_name)
+            return "%s%s = block->getData< %s >(%sID);" % (prefix, field_name, field_type, field_name)
     else:
         assert field_name.endswith('_tmp')
         original_field_name = field_name[:-len('_tmp')]
@@ -170,7 +170,7 @@ def generate_refs_for_kernel_parameters(kernel_info, prefix, parameters_to_ignor
 
 
 @jinja2.contextfilter
-def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=None, stream='0'):
+def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=None, stream='0', spatial_shape_symbols=[]):
     """Generates the function call to a pystencils kernel
 
     Args:
@@ -290,11 +290,14 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non
             kernel_call_lines += create_field_shape_code(fields[param.field_name], param.name)
 
     if not is_cpu:
-        spatial_shape_symbols = []
-        for param in ast_params:
-            if param.is_field_shape_argument:
-                spatial_shape_symbols = [TypedSymbol("%s_cpu[%d]" % (param.name, i), get_base_type(Field.SHAPE_DTYPE))
-                                         for i in range(field.spatial_dimensions)]
+        if not spatial_shape_symbols:
+            spatial_shape_symbols = []
+            for param in ast_params:
+                if param.is_field_shape_argument:
+                    spatial_shape_symbols = [TypedSymbol("%s_cpu[%d]" % (param.name, i), get_base_type(Field.SHAPE_DTYPE))
+                                             for i in range(field.spatial_dimensions)]
+        else:
+            spatial_shape_symbols = [TypedSymbol(e, get_base_type(Field.SHAPE_DTYPE)) for e in spatial_shape_symbols]
 
         if not spatial_shape_symbols:
             for param in ast_params:
-- 
GitLab