From ef5a7e6d51da0050b5c706fb2601e1aedaac1bc8 Mon Sep 17 00:00:00 2001
From: Ponsuganth Ilangovan Ponkumar Ilango
 <pponkumar@geophysik.uni-muenchen.de>
Date: Tue, 5 Nov 2024 11:06:39 +0100
Subject: [PATCH] Fix the value comparison test by evaluating the symbols

---
 hog/operator_generation/indexing.py           | 111 +++++++++++++-----
 .../operator_generation/test_indexing.py      |   6 +-
 2 files changed, 87 insertions(+), 30 deletions(-)

diff --git a/hog/operator_generation/indexing.py b/hog/operator_generation/indexing.py
index 2196b7d..896bcae 100644
--- a/hog/operator_generation/indexing.py
+++ b/hog/operator_generation/indexing.py
@@ -150,7 +150,10 @@ def generalized_macro_cell_index(
 def num_microvertices_per_face_from_width(width: int) -> int:
     """Computes the number of microvertices in a refined macro triangle. width depends on the level and quantifies the amount of primitives in one direction of the refined triangle."""
 
-    return int_div((width * (width + 1)), 2)
+    if USE_SYMPY_INT_DIV:
+        return sympy_int_div((width * (width + 1)), 2)
+    else:
+        return int_div((width * (width + 1)), 2)
 
 
 def num_microvertices_per_cell_from_width(width: int) -> int:
@@ -197,7 +200,9 @@ def num_faces_per_row_by_type(
 
 
 def num_cells_per_row_by_type(
-    level: int, cellType: Union[None, EdgeType, FaceType, CellType], num_microedges_per_edge: sp.Symbol
+    level: int,
+    cellType: Union[None, EdgeType, FaceType, CellType],
+    num_microedges_per_edge: sp.Symbol,
 ) -> int:
     if cellType == CellType.WHITE_UP:
         return num_microedges_per_edge
@@ -221,7 +226,9 @@ def num_micro_faces_per_macro_face(level: int, faceType: FaceType) -> int:
     )
 
 
-def num_micro_cells_per_macro_cell(level: int, cellType: CellType, num_microedges_per_edge: sp.Symbol) -> int:
+def num_micro_cells_per_macro_cell(
+    level: int, cellType: CellType, num_microedges_per_edge: sp.Symbol
+) -> int:
     return num_microvertices_per_cell_from_width(
         num_cells_per_row_by_type(level, cellType, num_microedges_per_edge)
     )
@@ -255,7 +262,7 @@ def facedof_index(
     index: Tuple[int, int, int],
     faceType: Union[None, EdgeType, FaceType, CellType],
     num_microfaces_per_face: sp.Symbol,
-    num_microedges_per_edge: sp.Symbol
+    num_microedges_per_edge: sp.Symbol,
 ) -> int:
     """Indexes triangles/faces. Used to compute offsets in volume dof indexing in 2D and AoS layout."""
     x, y, _ = index
@@ -263,7 +270,9 @@ def facedof_index(
     if faceType == FaceType.GRAY:
         return linear_macro_face_index(num_microedges_per_edge, x, y)
     elif faceType == FaceType.BLUE:
-        return num_microvertices_per_face_from_width(num_microedges_per_edge) + linear_macro_face_index(num_microedges_per_edge - 1, x, y)
+        return num_microvertices_per_face_from_width(
+            num_microedges_per_edge
+        ) + linear_macro_face_index(num_microedges_per_edge - 1, x, y)
     else:
         raise HOGException(f"Unexpected face type: {faceType}")
 
@@ -275,8 +284,10 @@ def celldof_index(
     num_microedges_per_edge: sp.Symbol,
 ) -> int:
     """Indexes cells/tetrahedra. Used to compute offsets in volume dof indexing in 3D and AoS layout."""
-    x, y, z = index 
-    width = num_cells_per_row_by_type(level, cellType, num_microedges_per_edge)  # gives expr(level)
+    x, y, z = index
+    width = num_cells_per_row_by_type(
+        level, cellType, num_microedges_per_edge
+    )  # gives expr(level)
     if cellType == CellType.WHITE_UP:
         return linear_macro_cell_index(width, x, y, z)
     elif cellType == CellType.BLUE_UP:
@@ -285,32 +296,60 @@ def celldof_index(
         ) + linear_macro_cell_index(width, x, y, z)
     elif cellType == CellType.GREEN_UP:
         return (
-            num_micro_cells_per_macro_cell(level, CellType.WHITE_UP, num_microedges_per_edge)
-            + num_micro_cells_per_macro_cell(level, CellType.BLUE_UP, num_microedges_per_edge)
+            num_micro_cells_per_macro_cell(
+                level, CellType.WHITE_UP, num_microedges_per_edge
+            )
+            + num_micro_cells_per_macro_cell(
+                level, CellType.BLUE_UP, num_microedges_per_edge
+            )
             + linear_macro_cell_index(width, x, y, z)
         )
     elif cellType == CellType.WHITE_DOWN:
         return (
-            num_micro_cells_per_macro_cell(level, CellType.WHITE_UP, num_microedges_per_edge)
-            + num_micro_cells_per_macro_cell(level, CellType.BLUE_UP, num_microedges_per_edge)
-            + num_micro_cells_per_macro_cell(level, CellType.GREEN_UP, num_microedges_per_edge)
+            num_micro_cells_per_macro_cell(
+                level, CellType.WHITE_UP, num_microedges_per_edge
+            )
+            + num_micro_cells_per_macro_cell(
+                level, CellType.BLUE_UP, num_microedges_per_edge
+            )
+            + num_micro_cells_per_macro_cell(
+                level, CellType.GREEN_UP, num_microedges_per_edge
+            )
             + linear_macro_cell_index(width, x, y, z)
         )
     elif cellType == CellType.BLUE_DOWN:
         return (
-            num_micro_cells_per_macro_cell(level, CellType.WHITE_UP, num_microedges_per_edge)
-            + num_micro_cells_per_macro_cell(level, CellType.BLUE_UP, num_microedges_per_edge)
-            + num_micro_cells_per_macro_cell(level, CellType.GREEN_UP, num_microedges_per_edge)
-            + num_micro_cells_per_macro_cell(level, CellType.WHITE_DOWN, num_microedges_per_edge)
+            num_micro_cells_per_macro_cell(
+                level, CellType.WHITE_UP, num_microedges_per_edge
+            )
+            + num_micro_cells_per_macro_cell(
+                level, CellType.BLUE_UP, num_microedges_per_edge
+            )
+            + num_micro_cells_per_macro_cell(
+                level, CellType.GREEN_UP, num_microedges_per_edge
+            )
+            + num_micro_cells_per_macro_cell(
+                level, CellType.WHITE_DOWN, num_microedges_per_edge
+            )
             + linear_macro_cell_index(width, x, y, z)
         )
     elif cellType == CellType.GREEN_DOWN:
         return (
-            num_micro_cells_per_macro_cell(level, CellType.WHITE_UP, num_microedges_per_edge)
-            + num_micro_cells_per_macro_cell(level, CellType.BLUE_UP, num_microedges_per_edge)
-            + num_micro_cells_per_macro_cell(level, CellType.GREEN_UP, num_microedges_per_edge)
-            + num_micro_cells_per_macro_cell(level, CellType.WHITE_DOWN, num_microedges_per_edge)
-            + num_micro_cells_per_macro_cell(level, CellType.BLUE_DOWN, num_microedges_per_edge)
+            num_micro_cells_per_macro_cell(
+                level, CellType.WHITE_UP, num_microedges_per_edge
+            )
+            + num_micro_cells_per_macro_cell(
+                level, CellType.BLUE_UP, num_microedges_per_edge
+            )
+            + num_micro_cells_per_macro_cell(
+                level, CellType.GREEN_UP, num_microedges_per_edge
+            )
+            + num_micro_cells_per_macro_cell(
+                level, CellType.WHITE_DOWN, num_microedges_per_edge
+            )
+            + num_micro_cells_per_macro_cell(
+                level, CellType.BLUE_DOWN, num_microedges_per_edge
+            )
             + linear_macro_cell_index(width, x, y, z)
         )
     else:
@@ -387,7 +426,7 @@ class DoFIndex:
 
     def array_index(
         self, geometry: ElementGeometry, indexing_info: IndexingInfo
-    ) -> int:
+    ) -> int | sp.Symbol:
         """
         Computes the array index of the passed DoF.
         """
@@ -404,11 +443,18 @@ class DoFIndex:
         elif self.dof_type == DoFType.EDGE:
             width = indexing_info.micro_edges_per_macro_edge
             if isinstance(geometry, TriangleElement):
-                micro_edges_one_type_per_macro_face = int_div(
-                    (indexing_info.micro_edges_per_macro_edge + 1)
-                    * indexing_info.micro_edges_per_macro_edge,
-                    2,
-                )
+                if USE_SYMPY_INT_DIV:
+                    micro_edges_one_type_per_macro_face = sympy_int_div(
+                        (indexing_info.micro_edges_per_macro_edge + 1)
+                        * indexing_info.micro_edges_per_macro_edge,
+                        2,
+                    )
+                else:
+                    micro_edges_one_type_per_macro_face = int_div(
+                        (indexing_info.micro_edges_per_macro_edge + 1)
+                        * indexing_info.micro_edges_per_macro_edge,
+                        2,
+                    )
 
                 order: List[Union[None, EdgeType, FaceType, CellType]] = [
                     EdgeType.X,
@@ -447,7 +493,11 @@ class DoFIndex:
                 numMicroVolumes = indexing_info.num_microfaces_per_face
 
                 microVolume = facedof_index(
-                    indexing_info.level, self.primitive_index, self.dof_sub_type, indexing_info.num_microfaces_per_face, indexing_info.micro_edges_per_macro_edge
+                    indexing_info.level,
+                    self.primitive_index,
+                    self.dof_sub_type,
+                    indexing_info.num_microfaces_per_face,
+                    indexing_info.micro_edges_per_macro_edge,
                 )
 
                 if self.mem_layout == VolumeDoFMemoryLayout.SoA:
@@ -463,7 +513,10 @@ class DoFIndex:
                 numMicroVolumes = indexing_info.num_microcells_per_cell
 
                 microVolume = celldof_index(
-                    indexing_info.level, self.primitive_index, self.dof_sub_type, indexing_info.micro_edges_per_macro_edge
+                    indexing_info.level,
+                    self.primitive_index,
+                    self.dof_sub_type,
+                    indexing_info.micro_edges_per_macro_edge,
                 )
 
                 if self.mem_layout == VolumeDoFMemoryLayout.SoA:
diff --git a/hog_tests/operator_generation/test_indexing.py b/hog_tests/operator_generation/test_indexing.py
index b30b371..cdebad9 100644
--- a/hog_tests/operator_generation/test_indexing.py
+++ b/hog_tests/operator_generation/test_indexing.py
@@ -37,6 +37,7 @@ from hog.operator_generation.indexing import (
     VolumeDoFMemoryLayout,
     num_microfaces_per_face,
     num_microcells_per_cell,
+    num_microedges_per_edge,
 )
 
 
@@ -277,7 +278,10 @@ def test_micro_volume_to_volume_indices():
         array_index = sp.simplify(
             dof_indices[intra_primitive_index].array_index(geometry, indexing_info)
         )
-        
+        array_index = array_index.subs(
+            [(indexing_info.micro_edges_per_macro_edge, num_microedges_per_edge(level))]
+        )
+
         assert array_index == target_array_index
 
     # 2D, P0:
-- 
GitLab