From a2a08c23a4597fd73e0dfa02a553f81776aa4320 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 11 Apr 2025 12:15:42 +0200
Subject: [PATCH] fix data types on boundary force vector

---
 src/lbmpy/boundaries/boundaryhandling.py |  7 +++++--
 tests/test_boundary_handling.py          | 24 +++++++++++++++++++++---
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/src/lbmpy/boundaries/boundaryhandling.py b/src/lbmpy/boundaries/boundaryhandling.py
index 9e1e7104..5c048b37 100644
--- a/src/lbmpy/boundaries/boundaryhandling.py
+++ b/src/lbmpy/boundaries/boundaryhandling.py
@@ -12,6 +12,9 @@ from lbmpy.advanced_streaming.utility import is_inplace, Timestep, AccessPdfValu
 
 from .._compat import IS_PYSTENCILS_2
 
+if IS_PYSTENCILS_2:
+    from pystencils.types import PsNumericType
+
 
 class LatticeBoltzmannBoundaryHandling(BoundaryHandling):
     """
@@ -188,10 +191,10 @@ def create_lattice_boltzmann_boundary_kernel(pdf_field, index_field, lb_method,
             **kernel_creation_args
         )
 
-        default_data_type = config.get_option("default_dtype")
+        default_data_type: PsNumericType = config.get_option("default_dtype")
 
         if force_vector is None:
-            force_vector_type = np.dtype([(f"F_{i}", default_data_type.c_name) for i in range(dim)], align=True)
+            force_vector_type = np.dtype([(f"F_{i}", default_data_type.numpy_dtype) for i in range(dim)], align=True)
             force_vector = Field.create_generic('force_vector', spatial_dimensions=1,
                                                 dtype=force_vector_type, field_type=FieldType.INDEXED)
 
diff --git a/tests/test_boundary_handling.py b/tests/test_boundary_handling.py
index 25e14e54..d0750f96 100644
--- a/tests/test_boundary_handling.py
+++ b/tests/test_boundary_handling.py
@@ -15,6 +15,8 @@ from pystencils import create_data_handling, make_slice, Target, CreateKernelCon
 from pystencils.slicing import slice_from_direction
 from pystencils.stencil import inverse_direction
 
+from lbmpy._compat import IS_PYSTENCILS_2
+
 
 def mirror_stencil(direction, mirror_axis):
     for i, n in enumerate(mirror_axis):
@@ -448,8 +450,17 @@ def test_force_on_boundary(given_force_vector, dtype):
     method = create_lb_method(lbm_config=LBMConfig(stencil=stencil, method=Method.SRT, relaxation_rate=1.8))
 
     noslip = NoSlip(name="noslip", calculate_force_on_boundary=True)
-    bouzidi = NoSlipLinearBouzidi(name="bouzidi", calculate_force_on_boundary=True)
-    qq_bounce_Back = QuadraticBounceBack(name="qqBB", relaxation_rate=1.8, calculate_force_on_boundary=True)
+    bouzidi = NoSlipLinearBouzidi(
+        name="bouzidi",
+        data_type=dtype,
+        calculate_force_on_boundary=True
+    )
+    qq_bounce_Back = QuadraticBounceBack(
+        name="qqBB",
+        relaxation_rate=1.8,
+        data_type=dtype,
+        calculate_force_on_boundary=True
+    )
 
     boundary_objects = [noslip, bouzidi, qq_bounce_Back]
     for boundary in boundary_objects:
@@ -465,7 +476,14 @@ def test_force_on_boundary(given_force_vector, dtype):
         index_field = ps.Field('indexVector', ps.FieldType.INDEXED, index_struct_dtype, layout=[0],
                                shape=(ps.TypedSymbol("indexVectorSize", "int32"), 1), strides=(1, 1))
 
-        create_lattice_boltzmann_boundary_kernel(pdfs, index_field, method, boundary, force_vector=force_vector)
+        create_lattice_boltzmann_boundary_kernel(
+            pdfs,
+            index_field,
+            method,
+            boundary,
+            force_vector=force_vector,
+            **({"default_dtype": dtype} if IS_PYSTENCILS_2 else dict())
+        )
 
 
 def _numpy_data_type_for_boundary_object(boundary_object, dim):
-- 
GitLab