diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py index 58340c3e0fbb16b98af2cf08c3d1894ca34a2309..f0a66ac840b42afe66eb37fc17ab4ec87ae16556 100644 --- a/src/pystencils/boundaries/boundaryhandling.py +++ b/src/pystencils/boundaries/boundaryhandling.py @@ -4,6 +4,7 @@ import numpy as np import sympy as sp from pystencils import create_kernel, CreateKernelConfig, Target +from pystencils.types import UserTypeSpec, create_numeric_type from pystencils.assignment import Assignment from pystencils.boundaries.createindexlist import ( create_boundary_index_array, numpy_data_type_for_boundary_object) @@ -84,13 +85,14 @@ class FlagInterface: class BoundaryHandling: def __init__(self, data_handling, field_name, stencil, name="boundary_handling", flag_interface=None, - target: Target = Target.CPU, openmp=True): + target: Target = Target.CPU, default_dtype: UserTypeSpec = "float64", openmp=True): assert data_handling.has_data(field_name) assert data_handling.dim == len(stencil[0]), "Dimension of stencil and data handling do not match" self._data_handling = data_handling self._field_name = field_name self._index_array_name = name + "IndexArrays" self._target = target + self._default_dtype = create_numeric_type(default_dtype) self._openmp = openmp self._boundary_object_to_boundary_info = {} self.stencil = stencil @@ -313,8 +315,11 @@ class BoundaryHandling: return self._boundary_object_to_boundary_info[boundary_obj].flag def _create_boundary_kernel(self, symbolic_field, symbolic_index_field, boundary_obj): - return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj, - target=self._target, cpu_openmp=self._openmp) + cfg = CreateKernelConfig() + cfg.target = self._target + cfg.default_dtype = self._default_dtype + cfg.cpu.openmp.enable = self._openmp + return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj, cfg) def _create_index_fields(self): dh = self._data_handling @@ -452,11 +457,14 @@ class BoundaryOffsetInfo: return sp.Symbol("invdir") -def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args): +def create_boundary_kernel(field, index_field, stencil, boundary_functor, cfg: CreateKernelConfig): # TODO: reconsider how to control the index_dtype in boundary kernels - config = CreateKernelConfig(index_field=index_field, target=target, index_dtype=SInt(32), **kernel_creation_args) + config = cfg.copy() + config.index_field = index_field + idx_dtype = SInt(32) + config.index_dtype = idx_dtype - offset_info = BoundaryOffsetInfo(stencil, config.index_dtype) + offset_info = BoundaryOffsetInfo(stencil, idx_dtype) elements = offset_info.get_array_declarations() dir_symbol = TypedSymbol("dir", config.index_dtype) elements += [Assignment(dir_symbol, index_field[0]('dir'))] diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py index 8e7e54ff1125a8bba2ba35c223277ee2867c28b7..295289ac01dd5d204101f8734fd27caf43ddebba 100644 --- a/src/pystencils/codegen/config.py +++ b/src/pystencils/codegen/config.py @@ -682,7 +682,7 @@ class CreateKernelConfig(ConfigBase): if cpu_vectorize_info is not None: _deprecated_option("cpu_vectorize_info", "cpu_optim.vectorize") if "instruction_set" in cpu_vectorize_info: - if self.target != Target.GenericCPU: + if self.target is not None and self.target != Target.GenericCPU: raise ValueError( "Setting 'instruction_set' in the deprecated 'cpu_vectorize_info' option is only " "valid if `target == Target.CPU`." diff --git a/src/pystencils/simp/simplifications.py b/src/pystencils/simp/simplifications.py index 9368c8f51a4aabd03c15a0741db5930eb8865884..baecf6cb4118770d64a582310c3962facf95b99a 100644 --- a/src/pystencils/simp/simplifications.py +++ b/src/pystencils/simp/simplifications.py @@ -1,13 +1,20 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + from itertools import chain from typing import Callable, List, Sequence, Union from collections import defaultdict import sympy as sp +from ..types import UserTypeSpec from ..assignment import Assignment -from ..sympyextensions import subs_additive, is_constant, recursive_collect +from ..sympyextensions import subs_additive, is_constant, recursive_collect, tcast from ..sympyextensions.typed_sympy import TypedSymbol +if TYPE_CHECKING: + from .assignment_collection import AssignmentCollection + # TODO rewrite with SymPy AST # def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: @@ -170,14 +177,19 @@ def add_subexpressions_for_sums(ac): return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) -def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None): +def add_subexpressions_for_field_reads( + ac: AssignmentCollection, + subexpressions=True, + main_assignments=True, + data_type: UserTypeSpec | None = None +): r"""Substitutes field accesses on rhs of assignments with subexpressions Can change semantics of the update rule (which is the goal of this transformation) This is useful if a field should be update in place - all values are loaded before into subexpression variables, then the new values are computed and written to the same field in-place. Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have - this data type. This is useful for mixed precision kernels + this data type, and an explicit cast is inserted. This is useful for mixed precision kernels """ field_reads = set() to_iterate = [] @@ -201,8 +213,23 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments substitutions.update({fa: TypedSymbol(lhs.name, data_type)}) else: substitutions.update({fa: lhs}) - return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, - substitute_on_lhs=False, sort_topologically=False) + + ac = ac.new_with_substitutions( + substitutions, + add_substitutions_as_subexpressions=False, + substitute_on_lhs=False, + sort_topologically=False + ) + + loads: list[Assignment] = [] + for fa in field_reads: + rhs = fa if data_type is None else tcast(fa, data_type) + loads.append( + Assignment(substitutions[fa], rhs) + ) + + ac.subexpressions = loads + ac.subexpressions + return ac def transform_rhs(assignment_list, transformation, *args, **kwargs): diff --git a/tests/frontend/test_simplifications.py b/tests/frontend/test_simplifications.py index 45cde724108fe7578d8ff2dc9b8a2509a9add728..771f82159630f3d96ce05298ed3e70ddea440b1b 100644 --- a/tests/frontend/test_simplifications.py +++ b/tests/frontend/test_simplifications.py @@ -147,6 +147,8 @@ def test_add_subexpressions_for_field_reads(): assert len(ac3.subexpressions) == 2 assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol) assert ac3.subexpressions[0].lhs.dtype == create_type("float32") + assert isinstance(ac3.subexpressions[0].rhs, ps.tcast) + assert ac3.subexpressions[0].rhs.dtype == create_type("float32") # TODO: What does this test mean to accomplish? diff --git a/tests/runtime/test_boundary.py b/tests/runtime/test_boundary.py index 226510b83d8832a5a189552df5c8760235f0d598..422553bcafb0ca1278f70f63a725d6f1cba8f496 100644 --- a/tests/runtime/test_boundary.py +++ b/tests/runtime/test_boundary.py @@ -222,15 +222,17 @@ def test_boundary_data_setter(): assert np.all(data_setter.link_positions(1) == 6.) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize('with_indices', ('with_indices', False)) -def test_dirichlet(with_indices): +def test_dirichlet(dtype, with_indices): value = (1, 20, 3) if with_indices else 1 dh = SerialDataHandling(domain_size=(7, 7)) - src = dh.add_array('src', values_per_cell=3 if with_indices else 1) - dh.cpu_arrays.src[...] = np.random.rand(*src.shape) + src = dh.add_array('src', values_per_cell=3 if with_indices else 1, dtype=dtype) + rng = np.random.default_rng() + dh.cpu_arrays.src[...] = rng.random(src.shape, dtype=dtype) boundary_stencil = [(1, 0), (-1, 0), (0, 1), (0, -1)] - boundary_handling = BoundaryHandling(dh, src.name, boundary_stencil) + boundary_handling = BoundaryHandling(dh, src.name, boundary_stencil, default_dtype=dtype) dirichlet = Dirichlet(value) assert dirichlet.name == 'Dirichlet' dirichlet.name = "wall"