diff --git a/pystencils_walberla/jinja_filters.py b/pystencils_walberla/jinja_filters.py index 9ce43a848b843d06d675c24cbb57fbf352a31817..7cc2d25369ce260e4c4ca34bfde7f47f25077ef8 100644 --- a/pystencils_walberla/jinja_filters.py +++ b/pystencils_walberla/jinja_filters.py @@ -2,7 +2,8 @@ import jinja2 import sympy as sp from pystencils import TypedSymbol -from pystencils.backends.cbackend import CustomSympyPrinter, generate_c +from pystencils.backends.cbackend import generate_c +from pystencils.backends.cuda_backend import CudaSympyPrinter from pystencils.data_types import get_base_type from pystencils.field import FieldType from pystencils.kernelparameters import SHAPE_DTYPE @@ -263,7 +264,7 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non spatial_shape_symbols = [TypedSymbol(s, SHAPE_DTYPE) for s in spatial_shape_symbols] indexing_dict = ast.indexing.call_parameters(spatial_shape_symbols) - sp_printer_c = CustomSympyPrinter(dialect='cuda') + sp_printer_c = CudaSympyPrinter() kernel_call_lines += [ "dim3 _block(int(%s), int(%s), int(%s));" % tuple(sp_printer_c.doprint(e) for e in indexing_dict['block']), "dim3 _grid(int(%s), int(%s), int(%s));" % tuple(sp_printer_c.doprint(e) for e in indexing_dict['grid']),