From a5e5f2c3ffb9aa5d228b155942cf5ad836363c77 Mon Sep 17 00:00:00 2001 From: Martin Bauer <martin.bauer@fau.de> Date: Wed, 31 Oct 2018 18:05:18 +0100 Subject: [PATCH] walberla boundary class now also works for GPU boundaries --- jinja_filters.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/jinja_filters.py b/jinja_filters.py index bb72cd2..8bc31cd 100644 --- a/jinja_filters.py +++ b/jinja_filters.py @@ -111,7 +111,7 @@ def field_extraction_code(field_accesses, field_name, is_temporary, declaration_ return "%s * %s;" % (field_type, field_name) else: prefix = "" if no_declaration else "auto " - return "%s%s = block->uncheckedFastGetData< %s >(%sID);" % (prefix, field_name, field_type, field_name) + return "%s%s = block->getData< %s >(%sID);" % (prefix, field_name, field_type, field_name) else: assert field_name.endswith('_tmp') original_field_name = field_name[:-len('_tmp')] @@ -170,7 +170,7 @@ def generate_refs_for_kernel_parameters(kernel_info, prefix, parameters_to_ignor @jinja2.contextfilter -def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=None, stream='0'): +def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=None, stream='0', spatial_shape_symbols=[]): """Generates the function call to a pystencils kernel Args: @@ -290,11 +290,14 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non kernel_call_lines += create_field_shape_code(fields[param.field_name], param.name) if not is_cpu: - spatial_shape_symbols = [] - for param in ast_params: - if param.is_field_shape_argument: - spatial_shape_symbols = [TypedSymbol("%s_cpu[%d]" % (param.name, i), get_base_type(Field.SHAPE_DTYPE)) - for i in range(field.spatial_dimensions)] + if not spatial_shape_symbols: + spatial_shape_symbols = [] + for param in ast_params: + if param.is_field_shape_argument: + spatial_shape_symbols = [TypedSymbol("%s_cpu[%d]" % (param.name, i), get_base_type(Field.SHAPE_DTYPE)) + for i in range(field.spatial_dimensions)] + else: + spatial_shape_symbols = [TypedSymbol(e, get_base_type(Field.SHAPE_DTYPE)) for e in spatial_shape_symbols] if not spatial_shape_symbols: for param in ast_params: -- GitLab