diff --git a/pystencils_walberla/jinja_filters.py b/pystencils_walberla/jinja_filters.py index 2709f8a5b2e2c99dff082095865a8696bef97a13..21611639a0a3f0771ad1375a82b35ac8b02f7101 100644 --- a/pystencils_walberla/jinja_filters.py +++ b/pystencils_walberla/jinja_filters.py @@ -1,3 +1,5 @@ +import itertools + import jinja2 import sympy as sp @@ -8,7 +10,10 @@ from pystencils.data_types import get_base_type from pystencils.field import FieldType from pystencils.kernelparameters import SHAPE_DTYPE from pystencils.sympyextensions import prod -from pystencils_walberla.special_symbols import AABBMin, CellIntervallMin +from pystencils_walberla.special_symbols import ( + AABBMax, AABBMin, Dx, aabb_max_vector, aabb_min_vector, dx_vector) + +SPECIAL_SYMBOL_NAMES = [s.name for s in itertools.chain(aabb_min_vector, aabb_max_vector, dx_vector)] temporary_fieldMemberTemplate = """ private: std::set< {type} *, field::SwapableCompare< {type} * > > cache_{original_field_name}_;""" @@ -134,11 +139,11 @@ def field_extraction_code(field, is_temporary, declaration_only=False, @jinja2.contextfilter -def generate_block_data_to_field_extraction(ctx, kernel_info, parameters_to_ignore=(), parameters=None, +def generate_block_data_to_field_extraction(ctx, kernel_info, parameters_to_ignore=[], parameters=None, declarations_only=False, no_declarations=False): """Generates code that extracts all required fields of a kernel from a walberla block storage.""" + parameters_to_ignore = itertools.chain(parameters_to_ignore, SPECIAL_SYMBOL_NAMES) if parameters is not None: - assert parameters_to_ignore == () field_parameters = [] for param in kernel_info.parameters: if param.is_field_pointer and param.field_name in parameters: @@ -235,7 +240,7 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non coordinates = get_start_coordinates(field) actual_gls = "int_c(%s->nrOfGhostLayers())" % (param.field_name, ) coord_set = set(coordinates) - coord_set = sorted(coord_set, key=lambda e: str(e)) + coord_set = sorted(coord_set, key=str) for c in coord_set: kernel_call_lines.append("WALBERLA_ASSERT_GREATER_EQUAL(%s, -%s);" % (c, actual_gls)) @@ -257,24 +262,25 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non max_value = "%s->%sSizeWithGhostLayer()" % (field.name, ('x', 'y', 'z')[coord]) kernel_call_lines.append("WALBERLA_ASSERT_GREATER_EQUAL(%s, %s);" % (max_value, shape)) kernel_call_lines.append("const %s %s = %s;" % (type_str, param.symbol.name, shape)) - elif isinstance(param.symbol, AABBMin): + elif isinstance(param.symbol, (AABBMin, AABBMax)): type_str = param.symbol.dtype - dim_letter = {0: 'X', 1: 'Y', 2: 'Z'}[param.symbol.dim] - code = f'block->getAABB.getMin{dim_letter}()' - kernel_call_lines.append("const %s %s = %s;" % (type_str, param.symbol.name, code)) - elif isinstance(param.symbol, CellIntervallMin): - type_str = param.symbol.dtype - if cell_interval is None: - get_global_ci = 'CellInterval global_ci = dynamic_cast< StructuredBlockStorage* >(&block->getBlockStorage())->getBlockCellBB(*block);' # noqa - else: - get_global_ci = f'CellInterval global_ci;\n'\ - f'blocks->transformBlockLocalToGlobalCellInterval(global_ci, *block, {cell_interval});' - if get_global_ci not in kernel_call_lines: - kernel_call_lines.append(get_global_ci) dim_letter = {0: 'x', 1: 'y', 2: 'z'}[param.symbol.dim] - code = f'global_ci.{dim_letter}Min()' + code = f'block->getAABB().{dim_letter}{"Min" if isinstance(param.symbol, AABBMin) else "Max"}()' kernel_call_lines.append("const %s %s = %s;" % (type_str, param.symbol.name, code)) + elif isinstance(param.symbol, Dx): + # Print those later, when field shapes are printed + pass + + for param in filter(lambda p: isinstance(p.symbol, Dx), ast_params): + type_str = param.symbol.dtype + coord = param.symbol.dim + + shape_param = next(filter(lambda x: x.is_field_shape and x.symbol.coordinate == coord, ast_params)) + + dim_letter = {0: 'x', 1: 'y', 2: 'z'}[param.symbol.dim] + code = f'block->getAABB().{dim_letter}Size() / static_cast<real_t>({shape_param})' + kernel_call_lines.append("const %s %s = %s;" % (type_str, param.symbol.name, code)) call_parameters = ", ".join([p.symbol.name for p in ast_params]) if not is_cpu: @@ -284,8 +290,8 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non else: spatial_shape_symbols = [TypedSymbol(s, SHAPE_DTYPE) for s in spatial_shape_symbols] - assert spatial_shape_symbols, "No shape parameters in kernel function arguments.\n"\ - "Please be only use kernels for generic field sizes!" + assert spatial_shape_symbols, "No shape parameters in kernel function arguments.\n" + "Please be only use kernels for generic field sizes!" indexing_dict = ast.indexing.call_parameters(spatial_shape_symbols) sp_printer_c = CudaSympyPrinter() @@ -313,6 +319,7 @@ def generate_constructor_initializer_list(kernel_info, parameters_to_ignore=None parameters_to_ignore = [] parameters_to_ignore += kernel_info.temporary_fields + parameters_to_ignore += tuple(SPECIAL_SYMBOL_NAMES) parameter_initializer_list = [] for param in kernel_info.parameters: @@ -327,6 +334,8 @@ def generate_constructor_parameters(kernel_info, parameters_to_ignore=None): if parameters_to_ignore is None: parameters_to_ignore = [] + parameters_to_ignore = list(itertools.chain(parameters_to_ignore, SPECIAL_SYMBOL_NAMES)) + varying_parameters = [] if hasattr(kernel_info, 'varying_parameters'): varying_parameters = kernel_info.varying_parameters @@ -344,10 +353,13 @@ def generate_constructor_parameters(kernel_info, parameters_to_ignore=None): @jinja2.contextfilter -def generate_members(ctx, kernel_info, parameters_to_ignore=(), only_fields=False): +def generate_members(ctx, kernel_info, parameters_to_ignore=[], only_fields=False): + ast = kernel_info.ast fields = {f.name: f for f in ast.fields_accessed} + parameters_to_ignore = itertools.chain(parameters_to_ignore, SPECIAL_SYMBOL_NAMES) + params_to_skip = tuple(parameters_to_ignore) + tuple(kernel_info.temporary_fields) params_to_skip += tuple(e[1] for e in kernel_info.varying_parameters) is_gpu = ctx['target'] == 'gpu' diff --git a/pystencils_walberla/special_symbols.py b/pystencils_walberla/special_symbols.py index 4461ef3187d393aa31988c3f2a8b7251a538da6d..e2672df28dcd895ec9363133b4b5a8c0fb324933 100644 --- a/pystencils_walberla/special_symbols.py +++ b/pystencils_walberla/special_symbols.py @@ -8,20 +8,20 @@ import pystencils from pystencils.data_types import TypedSymbol, create_type -class CellIntervallMin(TypedSymbol): +class AABBMax(TypedSymbol): """ Local cell interval in global index coordinates """ def __new__(cls, *args, **kwds): - obj = CellIntervallMin.__xnew_cached_(cls, *args, **kwds) + obj = AABBMax.__xnew_cached_(cls, *args, **kwds) return obj def __new_stage2__(cls, dim, *args, **kwargs): - obj = super(CellIntervallMin, cls).__xnew__(cls, - f'cell_interval_min_{dim}', - create_type('int64'), - *args, - **kwargs) + obj = super(AABBMax, cls).__xnew__(cls, + f'_aabb_max_{dim}', + create_type('double'), + *args, + **kwargs) obj.dim = dim return obj @@ -29,7 +29,7 @@ class CellIntervallMin(TypedSymbol): __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) def _hashable_content(self): - return super()._hashable_content(), self.dim + return super()._hashable_content(), self.dim, self.__class__.__name__ def __getnewargs__(self): return self.dim @@ -45,8 +45,8 @@ class AABBMin(TypedSymbol): def __new_stage2__(cls, dim, *args, **kwargs): obj = super(AABBMin, cls).__xnew__(cls, - f'aabb_min_{dim}', - create_type('int64'), + f'_aabb_min_{dim}', + create_type('double'), *args, **kwargs) obj.dim = dim @@ -56,22 +56,52 @@ class AABBMin(TypedSymbol): __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) def _hashable_content(self): - return super()._hashable_content(), self.dim + return super()._hashable_content(), self.dim, self.__class__.__name__ + + def __getnewargs__(self): + return self.dim + + +class Dx(TypedSymbol): + """ + Local cell interval in global index coordinates + """ + def __new__(cls, *args, **kwds): + obj = Dx.__xnew_cached_(cls, *args, **kwds) + return obj + + def __new_stage2__(cls, dim, *args, **kwargs): + obj = super(Dx, cls).__xnew__(cls, + f'_dx_{dim}', + create_type('double'), + *args, + **kwargs) + obj.dim = dim + return obj + + __xnew__ = staticmethod(__new_stage2__) + __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) + + def _hashable_content(self): + return super()._hashable_content(), self.dim, self.__class__.__name__ def __getnewargs__(self): return self.dim +dx_vector = sp.Matrix([Dx(i) for i in range(3)]) +dx = Dx(0) +dy = Dx(1) +dz = Dx(2) aabb_min_vector = sp.Matrix([AABBMin(i) for i in range(3)]) aabb_min_x = AABBMin(0) aabb_min_y = AABBMin(1) aabb_min_z = AABBMin(2) - -cell_interval_min_vector = sp.Matrix([CellIntervallMin(i) for i in range(3)]) -cell_interval_min_x = CellIntervallMin(0) -cell_interval_min_y = CellIntervallMin(1) -cell_interval_min_z = CellIntervallMin(2) -current_global_idx = cell_interval_min_vector + pystencils.x_vector(3) -current_global_x = cell_interval_min_x + pystencils.x_ -current_global_y = cell_interval_min_y + pystencils.y_ -current_global_z = cell_interval_min_z + pystencils.z_ +aabb_max_vector = sp.Matrix([AABBMax(i) for i in range(3)]) +aabb_max_x = AABBMax(0) +aabb_max_y = AABBMax(1) +aabb_max_z = AABBMax(2) +global_coord = aabb_min_vector + sp.Matrix([i*j for i, j in zip(dx_vector, pystencils.x_vector(3))]) +global_coord_x = global_coord[0] +global_coord_y = global_coord[1] +global_coord_z = global_coord[2]