From 6df2c640b0f88cf1882604e5917ea21f40687ba9 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 11 Apr 2025 12:02:14 +0200
Subject: [PATCH 1/4] insert casts in `add_subexpressions_for_field_reads`

---
 src/pystencils/simp/simplifications.py | 37 ++++++++++++++++++++++----
 1 file changed, 32 insertions(+), 5 deletions(-)

diff --git a/src/pystencils/simp/simplifications.py b/src/pystencils/simp/simplifications.py
index 9368c8f51..baecf6cb4 100644
--- a/src/pystencils/simp/simplifications.py
+++ b/src/pystencils/simp/simplifications.py
@@ -1,13 +1,20 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
 from itertools import chain
 from typing import Callable, List, Sequence, Union
 from collections import defaultdict
 
 import sympy as sp
 
+from ..types import UserTypeSpec
 from ..assignment import Assignment
-from ..sympyextensions import subs_additive, is_constant, recursive_collect
+from ..sympyextensions import subs_additive, is_constant, recursive_collect, tcast
 from ..sympyextensions.typed_sympy import TypedSymbol
 
+if TYPE_CHECKING:
+    from .assignment_collection import AssignmentCollection
+
 
 # TODO rewrite with SymPy AST
 # def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
@@ -170,14 +177,19 @@ def add_subexpressions_for_sums(ac):
     return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
 
 
-def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None):
+def add_subexpressions_for_field_reads(
+    ac: AssignmentCollection,
+    subexpressions=True,
+    main_assignments=True,
+    data_type: UserTypeSpec | None = None
+):
     r"""Substitutes field accesses on rhs of assignments with subexpressions
 
     Can change semantics of the update rule (which is the goal of this transformation)
     This is useful if a field should be update in place - all values are loaded before into subexpression variables,
     then the new values are computed and written to the same field in-place.
     Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
-    this data type. This is useful for mixed precision kernels
+    this data type, and an explicit cast is inserted. This is useful for mixed precision kernels
     """
     field_reads = set()
     to_iterate = []
@@ -201,8 +213,23 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
             substitutions.update({fa: TypedSymbol(lhs.name, data_type)})
         else:
             substitutions.update({fa: lhs})
-    return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
-                                     substitute_on_lhs=False, sort_topologically=False)
+    
+    ac = ac.new_with_substitutions(
+        substitutions,
+        add_substitutions_as_subexpressions=False,
+        substitute_on_lhs=False,
+        sort_topologically=False
+    )
+
+    loads: list[Assignment] = []
+    for fa in field_reads:
+        rhs = fa if data_type is None else tcast(fa, data_type)
+        loads.append(
+            Assignment(substitutions[fa], rhs)
+        )
+
+    ac.subexpressions = loads + ac.subexpressions
+    return ac
 
 
 def transform_rhs(assignment_list, transformation, *args, **kwargs):
-- 
GitLab


From a837cb73c966026cc40da690606e9db2d12e2f2d Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 11 Apr 2025 12:03:11 +0200
Subject: [PATCH 2/4] Permit custom numerical data type in BoundaryHandling

---
 src/pystencils/boundaries/boundaryhandling.py | 20 +++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py
index 58340c3e0..f0a66ac84 100644
--- a/src/pystencils/boundaries/boundaryhandling.py
+++ b/src/pystencils/boundaries/boundaryhandling.py
@@ -4,6 +4,7 @@ import numpy as np
 import sympy as sp
 
 from pystencils import create_kernel, CreateKernelConfig, Target
+from pystencils.types import UserTypeSpec, create_numeric_type
 from pystencils.assignment import Assignment
 from pystencils.boundaries.createindexlist import (
     create_boundary_index_array, numpy_data_type_for_boundary_object)
@@ -84,13 +85,14 @@ class FlagInterface:
 class BoundaryHandling:
 
     def __init__(self, data_handling, field_name, stencil, name="boundary_handling", flag_interface=None,
-                 target: Target = Target.CPU, openmp=True):
+                 target: Target = Target.CPU, default_dtype: UserTypeSpec = "float64", openmp=True):
         assert data_handling.has_data(field_name)
         assert data_handling.dim == len(stencil[0]), "Dimension of stencil and data handling do not match"
         self._data_handling = data_handling
         self._field_name = field_name
         self._index_array_name = name + "IndexArrays"
         self._target = target
+        self._default_dtype = create_numeric_type(default_dtype)
         self._openmp = openmp
         self._boundary_object_to_boundary_info = {}
         self.stencil = stencil
@@ -313,8 +315,11 @@ class BoundaryHandling:
         return self._boundary_object_to_boundary_info[boundary_obj].flag
 
     def _create_boundary_kernel(self, symbolic_field, symbolic_index_field, boundary_obj):
-        return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj,
-                                      target=self._target, cpu_openmp=self._openmp)
+        cfg = CreateKernelConfig()
+        cfg.target = self._target
+        cfg.default_dtype = self._default_dtype
+        cfg.cpu.openmp.enable = self._openmp
+        return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj, cfg)
 
     def _create_index_fields(self):
         dh = self._data_handling
@@ -452,11 +457,14 @@ class BoundaryOffsetInfo:
         return sp.Symbol("invdir")
 
 
-def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args):
+def create_boundary_kernel(field, index_field, stencil, boundary_functor, cfg: CreateKernelConfig):
     #   TODO: reconsider how to control the index_dtype in boundary kernels
-    config = CreateKernelConfig(index_field=index_field, target=target, index_dtype=SInt(32), **kernel_creation_args)
+    config = cfg.copy()
+    config.index_field = index_field
+    idx_dtype = SInt(32)
+    config.index_dtype = idx_dtype
 
-    offset_info = BoundaryOffsetInfo(stencil, config.index_dtype)
+    offset_info = BoundaryOffsetInfo(stencil, idx_dtype)
     elements = offset_info.get_array_declarations()
     dir_symbol = TypedSymbol("dir", config.index_dtype)
     elements += [Assignment(dir_symbol, index_field[0]('dir'))]
-- 
GitLab


From dc9b711b58b93db6158567ac046af36a360b8882 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 11 Apr 2025 12:28:51 +0200
Subject: [PATCH 3/4] fix sanity check for deprecated cpu_vectorize_info

---
 src/pystencils/codegen/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py
index 8e7e54ff1..295289ac0 100644
--- a/src/pystencils/codegen/config.py
+++ b/src/pystencils/codegen/config.py
@@ -682,7 +682,7 @@ class CreateKernelConfig(ConfigBase):
         if cpu_vectorize_info is not None:
             _deprecated_option("cpu_vectorize_info", "cpu_optim.vectorize")
             if "instruction_set" in cpu_vectorize_info:
-                if self.target != Target.GenericCPU:
+                if self.target is not None and self.target != Target.GenericCPU:
                     raise ValueError(
                         "Setting 'instruction_set' in the deprecated 'cpu_vectorize_info' option is only "
                         "valid if `target == Target.CPU`."
-- 
GitLab


From 94657080402f4e223c89e5cde2f6ab6a9e225b4a Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 11 Apr 2025 18:04:27 +0200
Subject: [PATCH 4/4] Add test cases

---
 tests/frontend/test_simplifications.py |  2 ++
 tests/runtime/test_boundary.py         | 10 ++++++----
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/tests/frontend/test_simplifications.py b/tests/frontend/test_simplifications.py
index 45cde7241..771f82159 100644
--- a/tests/frontend/test_simplifications.py
+++ b/tests/frontend/test_simplifications.py
@@ -147,6 +147,8 @@ def test_add_subexpressions_for_field_reads():
     assert len(ac3.subexpressions) == 2
     assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol)
     assert ac3.subexpressions[0].lhs.dtype == create_type("float32")
+    assert isinstance(ac3.subexpressions[0].rhs, ps.tcast)
+    assert ac3.subexpressions[0].rhs.dtype == create_type("float32")
 
 
 #   TODO: What does this test mean to accomplish?
diff --git a/tests/runtime/test_boundary.py b/tests/runtime/test_boundary.py
index 226510b83..422553bca 100644
--- a/tests/runtime/test_boundary.py
+++ b/tests/runtime/test_boundary.py
@@ -222,15 +222,17 @@ def test_boundary_data_setter():
         assert np.all(data_setter.link_positions(1) == 6.)
 
 
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
 @pytest.mark.parametrize('with_indices', ('with_indices', False))
-def test_dirichlet(with_indices):
+def test_dirichlet(dtype, with_indices):
     value = (1, 20, 3) if with_indices else 1
 
     dh = SerialDataHandling(domain_size=(7, 7))
-    src = dh.add_array('src', values_per_cell=3 if with_indices else 1)
-    dh.cpu_arrays.src[...] = np.random.rand(*src.shape)
+    src = dh.add_array('src', values_per_cell=3 if with_indices else 1, dtype=dtype)
+    rng = np.random.default_rng()
+    dh.cpu_arrays.src[...] = rng.random(src.shape, dtype=dtype)
     boundary_stencil = [(1, 0), (-1, 0), (0, 1), (0, -1)]
-    boundary_handling = BoundaryHandling(dh, src.name, boundary_stencil)
+    boundary_handling = BoundaryHandling(dh, src.name, boundary_stencil, default_dtype=dtype)
     dirichlet = Dirichlet(value)
     assert dirichlet.name == 'Dirichlet'
     dirichlet.name = "wall"
-- 
GitLab