From 69a63b0b761bd20a73ff49e23248601805744472 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 12 Nov 2024 08:39:16 +0100
Subject: [PATCH] Refactor DEFAULTS & fix bugs concerning data types of spatial
 counter symbols

---
 .../backend/kernelcreation/iteration_space.py |  6 +-
 src/pystencils/defaults.py                    | 49 +++++++------
 src/pystencils/py.typed                       |  0
 tests/kernelcreation/test_spatial_counters.py | 70 +++++++++++++++++++
 4 files changed, 100 insertions(+), 25 deletions(-)
 create mode 100644 src/pystencils/py.typed
 create mode 100644 tests/kernelcreation/test_spatial_counters.py

diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index 4f057e1fc..ba77ad24d 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -13,7 +13,7 @@ from ..memory import PsSymbol, PsBuffer
 from ..constants import PsConstant
 from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem
 from ..ast.util import failing_cast
-from ...types import PsStructType, constify
+from ...types import PsStructType
 from ..exceptions import PsInputError, KernelConstraintsError
 
 if TYPE_CHECKING:
@@ -359,7 +359,7 @@ def create_sparse_iteration_space(
     dim = archetype_field.spatial_dimensions
     coord_members = [
         PsStructType.Member(name, ctx.index_dtype)
-        for name in DEFAULTS._index_struct_coordinate_names[:dim]
+        for name in DEFAULTS.index_struct_coordinate_names[:dim]
     ]
 
     #   Determine index field
@@ -379,7 +379,7 @@ def create_sparse_iteration_space(
         )
 
     spatial_counters = [
-        ctx.get_symbol(name, constify(ctx.index_dtype))
+        ctx.get_symbol(name, ctx.index_dtype)
         for name in DEFAULTS.spatial_counter_names[:dim]
     ]
 
diff --git a/src/pystencils/defaults.py b/src/pystencils/defaults.py
index c7ac33347..0b6a48af1 100644
--- a/src/pystencils/defaults.py
+++ b/src/pystencils/defaults.py
@@ -1,13 +1,17 @@
-from typing import TypeVar, Generic, Callable
-from .types import PsType, PsIeeeFloatType, PsIntegerType, PsSignedIntegerType, PsStructType
+from .types import (
+    PsIeeeFloatType,
+    PsIntegerType,
+    PsSignedIntegerType,
+    PsStructType,
+    UserTypeSpec,
+    create_type,
+)
 
-from pystencils.sympyextensions.typed_sympy import TypedSymbol
+from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType
 
-SymbolT = TypeVar("SymbolT")
 
-
-class GenericDefaults(Generic[SymbolT]):
-    def __init__(self, symcreate: Callable[[str, PsType], SymbolT]):
+class SympyDefaults:
+    def __init__(self):
         self.numeric_dtype = PsIeeeFloatType(64)
         """Default data type for numerical computations"""
 
@@ -18,37 +22,38 @@ class GenericDefaults(Generic[SymbolT]):
         """Names of the default spatial counters"""
 
         self.spatial_counters = (
-            symcreate("ctr_0", self.index_dtype),
-            symcreate("ctr_1", self.index_dtype),
-            symcreate("ctr_2", self.index_dtype),
+            TypedSymbol("ctr_0", DynamicType.INDEX_TYPE),
+            TypedSymbol("ctr_1", DynamicType.INDEX_TYPE),
+            TypedSymbol("ctr_2", DynamicType.INDEX_TYPE),
         )
         """Default spatial counters"""
 
-        self._index_struct_coordinate_names = ("x", "y", "z")
+        self.index_struct_coordinate_names = ("x", "y", "z")
         """Default names of spatial coordinate members in index list structures"""
 
-        self.index_struct_coordinates = (
-            PsStructType.Member("x", self.index_dtype),
-            PsStructType.Member("y", self.index_dtype),
-            PsStructType.Member("z", self.index_dtype),
-        )
-        """Default spatial coordinate members in index list structures"""
-
         self.sparse_counter_name = "sparse_idx"
         """Name of the default sparse iteration counter"""
 
-        self.sparse_counter = symcreate(self.sparse_counter_name, self.index_dtype)
+        self.sparse_counter = TypedSymbol(
+            self.sparse_counter_name, DynamicType.INDEX_TYPE
+        )
         """Default sparse iteration counter."""
 
     def field_shape_name(self, field_name: str, coord: int):
         return f"_size_{field_name}_{coord}"
-    
+
     def field_stride_name(self, field_name: str, coord: int):
         return f"_stride_{field_name}_{coord}"
-    
+
     def field_pointer_name(self, field_name: str):
         return f"_data_{field_name}"
 
+    def index_struct(self, index_dtype: UserTypeSpec, dim: int) -> PsStructType:
+        idx_type = create_type(index_dtype)
+        return PsStructType(
+            [(name, idx_type) for name in self.index_struct_coordinate_names[:dim]]
+        )
+
 
-DEFAULTS = GenericDefaults[TypedSymbol](TypedSymbol)
+DEFAULTS = SympyDefaults()
 """Default names and symbols used throughout code generation"""
diff --git a/src/pystencils/py.typed b/src/pystencils/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/kernelcreation/test_spatial_counters.py b/tests/kernelcreation/test_spatial_counters.py
new file mode 100644
index 000000000..fdb365294
--- /dev/null
+++ b/tests/kernelcreation/test_spatial_counters.py
@@ -0,0 +1,70 @@
+import pytest
+import numpy as np
+
+from pystencils import (
+    Field,
+    Assignment,
+    create_kernel,
+    CreateKernelConfig,
+    DEFAULTS,
+    FieldType,
+)
+from pystencils.sympyextensions import CastFunc
+
+
+@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
+def test_spatial_counters_dense(index_dtype):
+    #   Parametrized over index_dtype to make sure the `DynamicType.INDEX` in the
+    #   DEFAULTS works validly
+    x, y, z = DEFAULTS.spatial_counters
+
+    f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
+
+    asms = [
+        Assignment(f(0), CastFunc.as_numeric(z)),
+        Assignment(f(1), CastFunc.as_numeric(y)),
+        Assignment(f(2), CastFunc.as_numeric(x)),
+    ]
+
+    cfg = CreateKernelConfig(index_dtype=index_dtype)
+    kernel = create_kernel(asms, cfg).compile()
+
+    f_arr = np.zeros((16, 16, 16, 3))
+    kernel(f=f_arr)
+
+    expected = np.mgrid[0:16, 0:16, 0:16].astype(np.float64).transpose()
+
+    np.testing.assert_equal(f_arr, expected)
+
+
+@pytest.mark.parametrize("index_dtype", ["int16", "int32", "uint32", "int64"])
+def test_spatial_counters_sparse(index_dtype):
+    x, y, z = DEFAULTS.spatial_counters
+
+    f = Field.create_generic("f", 3, "float64", index_shape=(3,), layout="fzyx")
+
+    asms = [
+        Assignment(f(0), CastFunc.as_numeric(x)),
+        Assignment(f(1), CastFunc.as_numeric(y)),
+        Assignment(f(2), CastFunc.as_numeric(z)),
+    ]
+
+    idx_struct = DEFAULTS.index_struct(index_dtype, 3)
+    idx_field = Field.create_generic(
+        "index", 1, idx_struct, field_type=FieldType.INDEXED
+    )
+
+    cfg = CreateKernelConfig(index_dtype=index_dtype, index_field=idx_field)
+    kernel = create_kernel(asms, cfg).compile()
+
+    f_arr = np.zeros((16, 16, 16, 3))
+    idx_arr = np.array(
+        [(1, 4, 3), (5, 1, 6), (9, 5, 1), (3, 13, 7)], dtype=idx_struct.numpy_dtype
+    )
+
+    kernel(f=f_arr, index=idx_arr)
+
+    for t in idx_arr:
+        assert f_arr[t[0], t[1], t[2], 0] == t[0].astype(np.float64)
+        assert f_arr[t[0], t[1], t[2], 1] == t[1].astype(np.float64)
+        assert f_arr[t[0], t[1], t[2], 2] == t[2].astype(np.float64)
-- 
GitLab