diff --git a/python/pystencils_walberla/compat.py b/python/pystencils_walberla/compat.py index ba3b018aa95f8d94dea4f867c1d91d943421785a..2e8f242b5d948bd1af282a1ee12ac599663e70ab 100644 --- a/python/pystencils_walberla/compat.py +++ b/python/pystencils_walberla/compat.py @@ -16,6 +16,8 @@ if IS_PYSTENCILS_2: from pystencils import DEFAULTS, Target, create_type from pystencils.types import PsType, PsDereferencableType, PsCustomType from pystencils import KernelFunction + from pystencils.backend.kernelfunction import KernelParameter + from pystencils.backend.properties import FieldShape, FieldStride from pystencils.backend.emission import emit_code, CAstPrinter def get_base_type(dtype: PsType): @@ -29,9 +31,12 @@ if IS_PYSTENCILS_2: def custom_type(typename: str): return PsCustomType(typename) - + + def typestr(dtype: PsType): + return dtype.c_string() + def get_default_dtype(config): - return create_type(config.default_dtype) + return create_type(config.default_dtype) class Backend(Enum): C = auto() @@ -70,6 +75,12 @@ if IS_PYSTENCILS_2: def get_supported_instruction_sets(): return () + def param_coordinate(param: KernelParameter): + prop: FieldShape | FieldStride = param.get_properties( + (FieldShape, FieldStride) + ).pop() + return prop.coordinate + else: # pystencils 1.x @@ -78,11 +89,15 @@ else: from pystencils.typing.typed_sympy import SHAPE_DTYPE from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets from pystencils.enums import Backend - from pystencils.backends.cbackend import generate_c, get_headers, CustomSympyPrinter, KernelFunction - + from pystencils.backends.cbackend import generate_c, get_headers, CustomSympyPrinter + from pystencils.astnodes import KernelFunction + def custom_type(typename: str): return typename - + + def typestr(type): + return str(type) + def get_default_dtype(config): return config.data_type.default_factory() @@ -91,3 +106,6 @@ else: def target_string(target: Target) -> str: return target.name.lower() + + def param_coordinate(param: KernelFunction.Parameter): + return param.symbol.coordinate diff --git a/python/pystencils_walberla/jinja_filters.py b/python/pystencils_walberla/jinja_filters.py index a32b2dd62050b605b91ad9b6e665cfa0d1515a45..e2802e58465d2163c90e6b78beede467dd3929cf 100644 --- a/python/pystencils_walberla/jinja_filters.py +++ b/python/pystencils_walberla/jinja_filters.py @@ -7,7 +7,13 @@ except ImportError: from collections.abc import Iterable import sympy as sp -from pystencils_walberla.compat import get_base_type, generate_c, Backend, IS_PYSTENCILS_2 +from pystencils_walberla.compat import ( + get_base_type, + generate_c, + Backend, + param_coordinate, + typestr +) from pystencils import Target, TypedSymbol, Field from pystencils.field import FieldType @@ -106,7 +112,7 @@ def get_field_stride(param): assert len(additional_strides) == field.index_dimensions f_stride_name = stride_names[-1] strides.extend([f"{type_str}({e} * {f_stride_name})" for e in reversed(additional_strides)]) - return strides[param.coordinate if IS_PYSTENCILS_2 else param.symbol.coordinate] + return strides[param_coordinate(param)] def generate_declaration(kernel_info, target=Target.CPU): @@ -326,7 +332,7 @@ def generate_call(ctx, kernel, ghost_layers_to_include=0, cell_interval=None, st if param.is_field_pointer: field = param.fields[0] if field.field_type == FieldType.BUFFER: - kernel_call_lines.append(f"{param.symbol.dtype} {param.symbol.name} = {param.field_name};") + kernel_call_lines.append(f"{typestr(param.symbol.dtype)} {param.symbol.name} = {param.field_name};") else: coordinates = get_start_coordinates(field) actual_gls = f"int_c({param.field_name}->nrOfGhostLayers())" @@ -337,7 +343,7 @@ def generate_call(ctx, kernel, ghost_layers_to_include=0, cell_interval=None, st while len(coordinates) < 4: coordinates.append(0) coordinates = tuple(coordinates) - kernel_call_lines.append(f"{param.symbol.dtype} {param.symbol.name} = {param.field_name}->dataAt" + kernel_call_lines.append(f"{typestr(param.symbol.dtype)} {param.symbol.name} = {param.field_name}->dataAt" f"({coordinates[0]}, {coordinates[1]}, {coordinates[2]}, {coordinates[3]});") if assume_inner_stride_one and field.index_dimensions > 0: kernel_call_lines.append(f"WALBERLA_ASSERT_EQUAL({param.field_name}->layout(), field::fzyx)") @@ -353,7 +359,7 @@ def generate_call(ctx, kernel, ghost_layers_to_include=0, cell_interval=None, st type_str = param.symbol.dtype.c_name kernel_call_lines.append(f"const {type_str} {param.symbol.name} = {casted_stride};") elif param.is_field_shape: - coord = param.coordinate if IS_PYSTENCILS_2 else param.symbol.coordinate + coord = param_coordinate(param) field = param.fields[0] type_str = param.symbol.dtype.c_name shape = f"{type_str}({get_end_coordinates(field)[coord]})" @@ -507,7 +513,7 @@ def generate_constructor_parameters(kernel_infos, parameters_to_ignore=None): for kernel_info in kernel_infos: for param in kernel_info.parameters: if not param.is_field_parameter and param.symbol.name not in parameters_to_skip: - parameter_list.append(f"{param.symbol.dtype} {param.symbol.name}") + parameter_list.append(f"{typestr(param.symbol.dtype)} {param.symbol.name}") parameters_to_skip.append(param.symbol.name) varying_parameters = ["%s %s" % e for e in varying_parameters] @@ -576,9 +582,9 @@ def generate_members(ctx, kernel_infos, parameters_to_ignore=None, only_fields=F continue if not param.is_field_parameter and param.symbol.name not in params_to_skip: if parameter_registration and param.symbol.name in parameter_registration.scaling_info: - result.append(f"std::vector<{param.symbol.dtype}> {param.symbol.name}Vector;") + result.append(f"std::vector<{typestr(param.symbol.dtype)}> {param.symbol.name}Vector;") else: - result.append(f"{param.symbol.dtype} {param.symbol.name}_;") + result.append(f"{typestr(param.symbol.dtype)} {param.symbol.name}_;") params_to_skip.append(param.symbol.name) for kernel_info in kernel_infos: @@ -620,14 +626,14 @@ def generate_plain_parameter_list(ctx, kernel_info, cell_interval=None, ghost_la for param in kernel_info.parameters: if not param.is_field_parameter and param.symbol.name: - result.append(f"{param.symbol.dtype} {param.symbol.name}") + result.append(f"{typestr(param.symbol.dtype)} {param.symbol.name}") if hasattr(kernel_info, 'varying_parameters'): result.extend(["%s %s_;" % e for e in kernel_info.varying_parameters]) # TODO due to backward compatibility with high level interface spec for parameter in kernel_info.kernel_selection_tree.get_selection_parameter_list(): - result.append(f"{parameter.dtype} {parameter.name}") + result.append(f"{typestr(parameter.dtype)} {parameter.name}") if cell_interval: result.append(f"const CellInterval & {cell_interval}") @@ -755,7 +761,7 @@ def type_identifier_list(nested_arg_list): if isinstance(s, str) and len(s) > 0: result.append(s) elif isinstance(s, TypedSymbol): - result.append(f"{s.dtype} {s.name}") + result.append(f"{typestr(s.dtype)} {s.name}") else: recursive_flatten(s) diff --git a/python/pystencils_walberla/kernel_info.py b/python/pystencils_walberla/kernel_info.py index 09a8af5e27195e3928c6a07ab9e4fd38e32cfd3c..cca65b0abd9e595e6d20c7b990a5a36f7aa5d851 100644 --- a/python/pystencils_walberla/kernel_info.py +++ b/python/pystencils_walberla/kernel_info.py @@ -3,7 +3,7 @@ from functools import reduce from pystencils import Target, TypedSymbol from pystencils_walberla.utility import merge_sorted_lists -from pystencils_walberla.compat import backend_printer, SHAPE_DTYPE, KernelFunction, IS_PYSTENCILS_2 +from pystencils_walberla.compat import backend_printer, SHAPE_DTYPE, KernelFunction, IS_PYSTENCILS_2, param_coordinate # TODO KernelInfo and KernelFamily should have same interface @@ -47,8 +47,9 @@ class KernelInfo: spatial_shape_symbols = kwargs.get('spatial_shape_symbols', ()) if not spatial_shape_symbols: - spatial_shape_symbols = [p.symbol for p in ast_params if p.is_field_shape] - spatial_shape_symbols.sort(key=lambda e: e.coordinate) + spatial_shape_params = [p for p in ast_params if p.is_field_shape] + spatial_shape_params.sort(key=lambda e: param_coordinate(e)) + spatial_shape_symbols = [e.symbol for e in spatial_shape_params] else: spatial_shape_symbols = [TypedSymbol(s, SHAPE_DTYPE) for s in spatial_shape_symbols] diff --git a/python/pystencils_walberla/kernel_selection.py b/python/pystencils_walberla/kernel_selection.py index 6fa95460dc017a885c4979e465b2f39af402b2f8..44d8713f0b625e64b60a46467ec786db582b2e4a 100644 --- a/python/pystencils_walberla/kernel_selection.py +++ b/python/pystencils_walberla/kernel_selection.py @@ -6,7 +6,7 @@ from jinja2.filters import do_indent from pystencils import Target, TypedSymbol from pystencils_walberla.utility import merge_lists_of_symbols, merge_sorted_lists -from pystencils_walberla.compat import backend_printer, get_headers, SHAPE_DTYPE, IS_PYSTENCILS_2 +from pystencils_walberla.compat import backend_printer, get_headers, SHAPE_DTYPE, IS_PYSTENCILS_2, param_coordinate """ @@ -182,8 +182,9 @@ class KernelCallNode(AbstractKernelSelectionNode): spatial_shape_symbols = kwargs.get('spatial_shape_symbols', ()) if not spatial_shape_symbols: - spatial_shape_symbols = [p.symbol for p in ast_params if p.is_field_shape] - spatial_shape_symbols.sort(key=lambda e: e.coordinate) + spatial_shape_params = [p for p in ast_params if p.is_field_shape] + spatial_shape_params.sort(key=lambda e: param_coordinate(e)) + spatial_shape_symbols = [e.symbol for e in spatial_shape_params] else: spatial_shape_symbols = [TypedSymbol(s, SHAPE_DTYPE) for s in spatial_shape_symbols] diff --git a/python/pystencils_walberla/utility.py b/python/pystencils_walberla/utility.py index d7f9ef36bf1a44d55b3166c5aaee638f73092f2e..6ff601f0b249eda78b2287b94a98891cc3178734 100644 --- a/python/pystencils_walberla/utility.py +++ b/python/pystencils_walberla/utility.py @@ -12,7 +12,7 @@ from pystencils_walberla.compat import get_supported_instruction_sets, BasicType from lbmpy import LBStencil from pystencils_walberla.cmake_integration import CodeGenerationContext -from pystencils_walberla.compat import PS_VERSION +from pystencils_walberla.compat import PS_VERSION, get_base_type HEADER_EXTENSIONS = {'.h', '.hpp'} @@ -195,9 +195,9 @@ def struct_from_numpy_dtype(struct_name, numpy_dtype): constructor_initializer_list = [] for name, (sub_type, offset) in numpy_dtype.fields.items(): pystencils_type = create_type(sub_type) - result += f" {pystencils_type} {name};\n" + result += f" {pystencils_type.c_name} {name};\n" if name in boundary_index_array_coordinate_names or name == direction_member_name: - constructor_params.append(f"{pystencils_type} {name}_") + constructor_params.append(f"{pystencils_type.c_name} {name}_") constructor_initializer_list.append(f"{name}({name}_)") else: constructor_initializer_list.append(f"{name}()") @@ -249,6 +249,6 @@ def _field_inclusion_code(field_typedefs): f_size = field.values_per_cell() dtype = get_base_type(field.dtype) headers.add("field/GhostLayerField.h") - typedefs[typename] = f"walberla::field::GhostLayerField<{dtype}, {f_size}>" + typedefs[typename] = f"walberla::field::GhostLayerField<{dtype.c_name}, {f_size}>" return headers, typedefs