diff --git a/jinja_filters.py b/jinja_filters.py index bb72cd20a55b6b94e3f91d58834443926750301a..8bc31cdf8c501043d25348a55e3ead38d9bf0985 100644 --- a/jinja_filters.py +++ b/jinja_filters.py @@ -111,7 +111,7 @@ def field_extraction_code(field_accesses, field_name, is_temporary, declaration_ return "%s * %s;" % (field_type, field_name) else: prefix = "" if no_declaration else "auto " - return "%s%s = block->uncheckedFastGetData< %s >(%sID);" % (prefix, field_name, field_type, field_name) + return "%s%s = block->getData< %s >(%sID);" % (prefix, field_name, field_type, field_name) else: assert field_name.endswith('_tmp') original_field_name = field_name[:-len('_tmp')] @@ -170,7 +170,7 @@ def generate_refs_for_kernel_parameters(kernel_info, prefix, parameters_to_ignor @jinja2.contextfilter -def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=None, stream='0'): +def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=None, stream='0', spatial_shape_symbols=[]): """Generates the function call to a pystencils kernel Args: @@ -290,11 +290,14 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non kernel_call_lines += create_field_shape_code(fields[param.field_name], param.name) if not is_cpu: - spatial_shape_symbols = [] - for param in ast_params: - if param.is_field_shape_argument: - spatial_shape_symbols = [TypedSymbol("%s_cpu[%d]" % (param.name, i), get_base_type(Field.SHAPE_DTYPE)) - for i in range(field.spatial_dimensions)] + if not spatial_shape_symbols: + spatial_shape_symbols = [] + for param in ast_params: + if param.is_field_shape_argument: + spatial_shape_symbols = [TypedSymbol("%s_cpu[%d]" % (param.name, i), get_base_type(Field.SHAPE_DTYPE)) + for i in range(field.spatial_dimensions)] + else: + spatial_shape_symbols = [TypedSymbol(e, get_base_type(Field.SHAPE_DTYPE)) for e in spatial_shape_symbols] if not spatial_shape_symbols: for param in ast_params: