From 7f6f3a573c5562d5829627631a21b1c443563cc6 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 22 Oct 2024 16:25:58 +0200
Subject: [PATCH] Adapt to changes from pycodegen/pystencils!420

---
 src/lbmpy/advanced_streaming/indexing.py   |  4 ++--
 src/lbmpy/boundaries/boundaryconditions.py | 24 ++++++++++++----------
 src/lbmpy/custom_code_nodes.py             |  2 +-
 src/lbmpy/lookup_tables.py                 | 19 +++++++++--------
 4 files changed, 27 insertions(+), 22 deletions(-)

diff --git a/src/lbmpy/advanced_streaming/indexing.py b/src/lbmpy/advanced_streaming/indexing.py
index fd393fb4..0999bba6 100644
--- a/src/lbmpy/advanced_streaming/indexing.py
+++ b/src/lbmpy/advanced_streaming/indexing.py
@@ -72,7 +72,7 @@ class BetweenTimestepsIndexing:
         inv = '_inv' if inverse else ''
         name = f"f_{f_dir}{inv}_dir_idx"
         if IS_PYSTENCILS_2:
-            return TypedSymbol(name, Arr(self._index_dtype))
+            return TypedSymbol(name, Arr(self._index_dtype, self._q))
         else:
             return TypedSymbol(name, self._index_dtype)
 
@@ -82,7 +82,7 @@ class BetweenTimestepsIndexing:
         name_base = f"f_{f_dir}{inv}_offsets_"
 
         if IS_PYSTENCILS_2:
-            symbols = [TypedSymbol(name_base + d, Arr(self._index_dtype)) for d in self._coordinate_names]
+            symbols = [TypedSymbol(name_base + d, Arr(self._index_dtype, self._q)) for d in self._coordinate_names]
         else:
             symbols = [TypedSymbol(name_base + d, self._index_dtype) for d in self._coordinate_names]
         
diff --git a/src/lbmpy/boundaries/boundaryconditions.py b/src/lbmpy/boundaries/boundaryconditions.py
index 420e4823..537c0bd6 100644
--- a/src/lbmpy/boundaries/boundaryconditions.py
+++ b/src/lbmpy/boundaries/boundaryconditions.py
@@ -224,12 +224,13 @@ class QuadraticBounceBack(LbBoundary):
         self.init_wall_distance = init_wall_distance
         self.equilibrium_values_name = "f_eq"
 
+        super(QuadraticBounceBack, self).__init__(name)
+
+    def inv_dir_symbol(self, stencil):
         if IS_PYSTENCILS_2:
-            self.inv_dir_symbol = TypedSymbol("inv_dir", Arr(create_type("int32")))
+            return TypedSymbol("inv_dir", Arr(create_type("int32"), stencil.Q))
         else:
-            self.inv_dir_symbol = TypedSymbol("inv_dir", create_type("int32"))
-
-        super(QuadraticBounceBack, self).__init__(name)
+            return TypedSymbol("inv_dir", create_type("int32"))
 
     @property
     def additional_data(self):
@@ -263,11 +264,12 @@ class QuadraticBounceBack(LbBoundary):
         inv_directions = [str(stencil.index(inverse_direction(direction))) for direction in stencil]
         
         if IS_PYSTENCILS_2:
-            inverse_dir_node = TranslationArraysNode([(self.inv_dir_symbol, inv_directions), ])
+            inverse_dir_node = TranslationArraysNode([(self.inv_dir_symbol(stencil), inv_directions), ])
         else:
-            dtype = self.inv_dir_symbol.dtype
-            name = self.inv_dir_symbol.name
-            inverse_dir_node = TranslationArraysNode([(dtype, name, inv_directions), ], {self.inv_dir_symbol})
+            inv_dir_symbol = self.inv_dir_symbol(stencil)
+            dtype = inv_dir_symbol.dtype
+            name = inv_dir_symbol.name
+            inverse_dir_node = TranslationArraysNode([(dtype, name, inv_directions), ], {inv_dir_symbol})
         
         return [LbmWeightInfo(lb_method, self.data_type), inverse_dir_node, NeighbourOffsetArrays(lb_method.stencil)]
 
@@ -286,7 +288,7 @@ class QuadraticBounceBack(LbBoundary):
 
     def __call__(self, f_out, f_in, dir_symbol, inv_dir, lb_method, index_field):
         omega = self.relaxation_rate
-        inv = sp.IndexedBase(self.inv_dir_symbol, shape=(1,))[dir_symbol]
+        inv = sp.IndexedBase(self.inv_dir_symbol(lb_method.stencil), shape=(1,))[dir_symbol]
         weight_info = LbmWeightInfo(lb_method, data_type=self.data_type)
         weight_of_direction = weight_info.weight_of_direction
         pdf_field_accesses = [f_out(i) for i in range(len(lb_method.stencil))]
@@ -461,7 +463,7 @@ class FreeSlip(LbBoundary):
         neighbor_offset = NeighbourOffsetArrays.neighbour_offset(dir_symbol, lb_method.stencil)
         if self.normal_direction:
             tangential_offset = tuple(offset + normal for offset, normal in zip(neighbor_offset, self.normal_direction))
-            mirrored_stencil_symbol = MirroredStencilDirections._mirrored_symbol(self.mirror_axis)
+            mirrored_stencil_symbol = MirroredStencilDirections._mirrored_symbol(self.mirror_axis, self.stencil)
             mirrored_direction = inv_dir[sp.IndexedBase(mirrored_stencil_symbol, shape=(1,))[dir_symbol]]
         else:
             normal_direction = list()
@@ -602,7 +604,7 @@ class WallFunctionBounce(LbBoundary):
         # neighbour offset symbols are basically the stencil directions defined in stencils.py:L130ff.
         neighbor_offset = NeighbourOffsetArrays.neighbour_offset(dir_symbol, lb_method.stencil)
         tangential_offset = tuple(offset + normal for offset, normal in zip(neighbor_offset, self.normal_direction))
-        mirrored_stencil_symbol = MirroredStencilDirections._mirrored_symbol(self.mirror_axis)
+        mirrored_stencil_symbol = MirroredStencilDirections._mirrored_symbol(self.mirror_axis, self.stencil)
         mirrored_direction = inv_dir[sp.IndexedBase(mirrored_stencil_symbol, shape=(1,))[dir_symbol]]
 
         name_base = "f_in_inv_offsets_"
diff --git a/src/lbmpy/custom_code_nodes.py b/src/lbmpy/custom_code_nodes.py
index 6a765db9..cf90a26b 100644
--- a/src/lbmpy/custom_code_nodes.py
+++ b/src/lbmpy/custom_code_nodes.py
@@ -49,7 +49,7 @@ class MirroredStencilDirections(CustomCodeNode):
         return tuple(direction)
 
     @staticmethod
-    def _mirrored_symbol(mirror_axis):
+    def _mirrored_symbol(mirror_axis, _stencil):
         axis = ['x', 'y', 'z']
         return TypedSymbol(f"{axis[mirror_axis]}_axis_mirrored_stencil_dir", create_type('int32'))
 
diff --git a/src/lbmpy/lookup_tables.py b/src/lbmpy/lookup_tables.py
index 5fb03114..50b2d4da 100644
--- a/src/lbmpy/lookup_tables.py
+++ b/src/lbmpy/lookup_tables.py
@@ -30,14 +30,16 @@ class NeighbourOffsetArrays(LookupTables):
             return tuple(
                 [
                     sp.IndexedBase(symbol, shape=(1,))[dir_idx]
-                    for symbol in NeighbourOffsetArrays._offset_symbols(len(stencil[0]))
+                    for symbol in NeighbourOffsetArrays._offset_symbols(stencil)
                 ]
             )
 
     @staticmethod
-    def _offset_symbols(dim):
+    def _offset_symbols(stencil):
+        q = len(stencil)
+        dim = len(stencil[0])
         return [
-            TypedSymbol(f"neighbour_offset_{d}", Arr(create_type("int32")))
+            TypedSymbol(f"neighbour_offset_{d}", Arr(create_type("int32"), q))
             for d in ["x", "y", "z"][:dim]
         ]
 
@@ -49,7 +51,7 @@ class NeighbourOffsetArrays(LookupTables):
         self._dim = len(stencil[0])
 
     def get_array_declarations(self) -> list[Assignment]:
-        array_symbols = NeighbourOffsetArrays._offset_symbols(self._dim)
+        array_symbols = NeighbourOffsetArrays._offset_symbols(self._stencil)
         return [
             Assignment(arrsymb, tuple((d[i] for d in self._stencil)))
             for i, arrsymb in enumerate(array_symbols)
@@ -69,17 +71,18 @@ class MirroredStencilDirections(LookupTables):
         return tuple(direction)
 
     @staticmethod
-    def _mirrored_symbol(mirror_axis):
+    def _mirrored_symbol(mirror_axis, stencil):
         axis = ["x", "y", "z"]
+        q = len(stencil)
         return TypedSymbol(
-            f"{axis[mirror_axis]}_axis_mirrored_stencil_dir", Arr(create_type("int32"))
+            f"{axis[mirror_axis]}_axis_mirrored_stencil_dir", Arr(create_type("int32"), q)
         )
 
     def __init__(self, stencil, mirror_axis, dtype=np.int32):
         self._offsets_dtype = create_type(dtype)  # TODO: Currently, this has no effect
 
         self._mirrored_stencil_symbol = MirroredStencilDirections._mirrored_symbol(
-            mirror_axis
+            mirror_axis, stencil
         )
         self._mirrored_directions = tuple(
             stencil.index(
@@ -94,8 +97,8 @@ class MirroredStencilDirections(LookupTables):
 
 class LbmWeightInfo(LookupTables):
     def __init__(self, lb_method, data_type="double"):
-        self._weights_array = TypedSymbol("weights", Arr(create_type(data_type)))
         self._weights = lb_method.weights
+        self._weights_array = TypedSymbol("weights", Arr(create_type(data_type), len(self._weights)))
 
     def weight_of_direction(self, dir_idx, lb_method=None):
         if isinstance(sp.sympify(dir_idx), sp.Integer):
-- 
GitLab