diff --git a/jinja_filters.py b/jinja_filters.py index 8e82794a6af7cc42e65f77f2a5a2eb20d54a584b..bb72cd20a55b6b94e3f91d58834443926750301a 100644 --- a/jinja_filters.py +++ b/jinja_filters.py @@ -1,6 +1,8 @@ import sympy as sp import jinja2 import copy + +from pystencils import TypedSymbol from pystencils.astnodes import ResolvedFieldAccess from pystencils.data_types import get_base_type from pystencils.backends.cbackend import generate_c, CustomSympyPrinter @@ -291,7 +293,7 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non spatial_shape_symbols = [] for param in ast_params: if param.is_field_shape_argument: - spatial_shape_symbols = [sp.Symbol("%s_cpu[%d]" % (param.name, i)) + spatial_shape_symbols = [TypedSymbol("%s_cpu[%d]" % (param.name, i), get_base_type(Field.SHAPE_DTYPE)) for i in range(field.spatial_dimensions)] if not spatial_shape_symbols: @@ -301,7 +303,8 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non field = fields[param.field_name] if field.field_type == FieldType.GENERIC: kernel_call_lines += create_field_shape_code(field, '_size', gpu_copy=False) - spatial_shape_symbols = [sp.Symbol("_size[%d]" % (i, )) for i in range(field.spatial_dimensions)] + spatial_shape_symbols = [TypedSymbol("_size[%d]" % (i, ), get_base_type(Field.SHAPE_DTYPE)) + for i in range(field.spatial_dimensions)] break indexing_dict = ast.indexing.call_parameters(spatial_shape_symbols)