Skip to content
Snippets Groups Projects
Commit cd1f52e1 authored by Martin Bauer's avatar Martin Bauer
Browse files

walberla integration + bugfix in GPU block indexing

parent 0f21a28f
No related branches found
No related tags found
No related merge requests found
import sympy as sp import sympy as sp
import jinja2 import jinja2
import copy import copy
from pystencils import TypedSymbol
from pystencils.astnodes import ResolvedFieldAccess from pystencils.astnodes import ResolvedFieldAccess
from pystencils.data_types import get_base_type from pystencils.data_types import get_base_type
from pystencils.backends.cbackend import generate_c, CustomSympyPrinter from pystencils.backends.cbackend import generate_c, CustomSympyPrinter
...@@ -291,7 +293,7 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non ...@@ -291,7 +293,7 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non
spatial_shape_symbols = [] spatial_shape_symbols = []
for param in ast_params: for param in ast_params:
if param.is_field_shape_argument: if param.is_field_shape_argument:
spatial_shape_symbols = [sp.Symbol("%s_cpu[%d]" % (param.name, i)) spatial_shape_symbols = [TypedSymbol("%s_cpu[%d]" % (param.name, i), get_base_type(Field.SHAPE_DTYPE))
for i in range(field.spatial_dimensions)] for i in range(field.spatial_dimensions)]
if not spatial_shape_symbols: if not spatial_shape_symbols:
...@@ -301,7 +303,8 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non ...@@ -301,7 +303,8 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non
field = fields[param.field_name] field = fields[param.field_name]
if field.field_type == FieldType.GENERIC: if field.field_type == FieldType.GENERIC:
kernel_call_lines += create_field_shape_code(field, '_size', gpu_copy=False) kernel_call_lines += create_field_shape_code(field, '_size', gpu_copy=False)
spatial_shape_symbols = [sp.Symbol("_size[%d]" % (i, )) for i in range(field.spatial_dimensions)] spatial_shape_symbols = [TypedSymbol("_size[%d]" % (i, ), get_base_type(Field.SHAPE_DTYPE))
for i in range(field.spatial_dimensions)]
break break
indexing_dict = ast.indexing.call_parameters(spatial_shape_symbols) indexing_dict = ast.indexing.call_parameters(spatial_shape_symbols)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment