From fad3752a84dfbdb22857d2066e8b835fb573458f Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 23 Jan 2025 09:03:40 +0100
Subject: [PATCH] WIP update field to permit DynamicType

---
 .../backend/kernelcreation/context.py         | 21 +++++-
 .../backend/kernelcreation/freeze.py          |  9 +--
 src/pystencils/field.py                       | 75 ++++++++++---------
 tests/frontend/test_field.py                  | 50 ++++++++++---
 4 files changed, 96 insertions(+), 59 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index 8f5931c64..68da893ff 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -93,6 +93,16 @@ class KernelCreationContext:
     def index_dtype(self) -> PsIntegerType:
         """Data type used by default for index expressions"""
         return self._index_dtype
+    
+    def resolve_dynamic_type(self, dtype: DynamicType | PsType) -> PsType:
+        """Selects the appropriate data type for `DynamicType` instances, and returns all other types as they are."""
+        match dtype:
+            case DynamicType.NUMERIC_TYPE:
+                return self._default_dtype
+            case DynamicType.INDEX_TYPE:
+                return self._index_dtype
+            case _:
+                return dtype
 
     @property
     def metadata(self) -> dict[str, Any]:
@@ -339,6 +349,8 @@ class KernelCreationContext:
             if isinstance(s, TypedSymbol)
         )
 
+        entry_type = self.resolve_dynamic_type(field.dtype)
+
         if len(idx_types) > 1:
             raise KernelConstraintsError(
                 f"Multiple incompatible types found in index symbols of field {field}: "
@@ -375,10 +387,10 @@ class KernelCreationContext:
 
         base_ptr = self.get_symbol(
             DEFAULTS.field_pointer_name(field.name),
-            PsPointerType(field.dtype, restrict=True),
+            PsPointerType(entry_type, restrict=True),
         )
 
-        return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides)
+        return PsBuffer(field.name, entry_type, base_ptr, buf_shape, buf_strides)
 
     def _create_buffer_field_buffer(self, field: Field) -> PsBuffer:
         if field.spatial_dimensions != 1:
@@ -418,10 +430,11 @@ class KernelCreationContext:
             ]
 
         buf_strides = [PsConstant(num_entries, idx_type), PsConstant(1, idx_type)]
+        buf_dtype = self.resolve_dynamic_type(field.dtype)
 
         base_ptr = self.get_symbol(
             DEFAULTS.field_pointer_name(field.name),
-            PsPointerType(field.dtype, restrict=True),
+            PsPointerType(buf_dtype, restrict=True),
         )
 
-        return PsBuffer(field.name, field.dtype, base_ptr, buf_shape, buf_strides)
+        return PsBuffer(field.name, buf_dtype, base_ptr, buf_shape, buf_strides)
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 16710861b..5d18ecf82 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -261,14 +261,7 @@ class FreezeExpressions:
         return num / denom
 
     def map_TypedSymbol(self, expr: TypedSymbol):
-        dtype = expr.dtype
-
-        match dtype:
-            case DynamicType.NUMERIC_TYPE:
-                dtype = self._ctx.default_dtype
-            case DynamicType.INDEX_TYPE:
-                dtype = self._ctx.index_dtype
-
+        dtype = self._ctx.resolve_dynamic_type(expr.dtype)
         symb = self._ctx.get_symbol(expr.name, dtype)
         return PsSymbolExpr(symb)
 
diff --git a/src/pystencils/field.py b/src/pystencils/field.py
index 1a3a13b73..826c551a7 100644
--- a/src/pystencils/field.py
+++ b/src/pystencils/field.py
@@ -12,13 +12,13 @@ import sympy as sp
 from sympy.core.cache import cacheit
 
 from .defaults import DEFAULTS
-from pystencils.alignedarray import aligned_empty
-from pystencils.spatial_coordinates import x_staggered_vector, x_vector
-from pystencils.stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string
-from pystencils.types import PsType, PsStructType, create_type
-from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType
-from pystencils.sympyextensions import is_integer_sequence
-from pystencils.types import UserTypeSpec
+from .alignedarray import aligned_empty
+from .spatial_coordinates import x_staggered_vector, x_vector
+from .stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string
+from .types import PsType, PsStructType, create_type
+from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
+from .sympyextensions import is_integer_sequence
+from .types import UserTypeSpec
 
 
 __all__ = ['Field', 'fields', 'FieldType', 'Field']
@@ -123,16 +123,16 @@ class Field:
         >>> assignments = [Assignment(dst[0,0](i), src[-offset](i)) for i, offset in enumerate(stencil)];
 
     Args:
-        field_name: something
-        field_type: something
-        dtype: something
-        layout: something
-        shape: something
-        strides: something
+        field_name: The field's name
+        field_type: The kind of the field
+        dtype: Data type of the field's entries
+        layout: Linearization order of the field's spatial dimensions
+        shape: Total shape (spatial and index) of the field
+        strides: Linearization strides of the field
     """
 
     @staticmethod
-    def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec = np.float64, index_dimensions=0, 
+    def create_generic(field_name, spatial_dimensions, dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, index_dimensions=0, 
                        layout='numpy', index_shape=None, field_type=FieldType.GENERIC) -> 'Field':
         """
         Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
@@ -177,15 +177,15 @@ class Field:
             for i in range(total_dimensions)
         ])
 
-        dtype = create_type(dtype)
-        np_data_type = dtype.numpy_dtype
-        assert np_data_type is not None
-        
-        if np_data_type.fields is not None:
+        if not isinstance(dtype, DynamicType):
+            dtype = create_type(dtype)
+
+        if isinstance(dtype, PsStructType):
             if index_dimensions != 0:
                 raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
             shape += (1,)
             strides += (1,)
+
         if field_type == FieldType.STAGGERED and index_dimensions == 0:
             raise ValueError("A staggered field needs at least one index dimension")
 
@@ -228,7 +228,7 @@ class Field:
 
     @staticmethod
     def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0,
-                          dtype: UserTypeSpec = np.float64, layout: str = 'numpy',
+                          dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE, layout: str = 'numpy',
                           strides: Optional[Sequence[int]] = None,
                           field_type=FieldType.GENERIC) -> 'Field':
         """
@@ -256,11 +256,10 @@ class Field:
             assert len(strides) == len(shape)
             strides = tuple([s // np.dtype(dtype).itemsize for s in strides])
 
-        dtype = create_type(dtype)
-        numpy_dtype = dtype.numpy_dtype
-        assert numpy_dtype is not None
+        if not isinstance(dtype, DynamicType):
+            dtype = create_type(dtype)
 
-        if numpy_dtype.fields is not None:
+        if isinstance(dtype, PsStructType):
             if index_dimensions != 0:
                 raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
             shape += (1,)
@@ -277,7 +276,7 @@ class Field:
         self,
         field_name: str,
         field_type: FieldType,
-        dtype: UserTypeSpec,
+        dtype: UserTypeSpec | DynamicType,
         layout: tuple[int, ...],
         shape,
         strides
@@ -287,7 +286,7 @@ class Field:
         assert isinstance(field_type, FieldType)
         assert len(shape) == len(strides)
         self.field_type = field_type
-        self._dtype = create_type(dtype)
+        self._dtype: PsType | DynamicType = create_type(dtype) if not isinstance(dtype, DynamicType) else dtype
         self._layout = normalize_layout(layout)
         self.shape = shape
         self.strides = strides
@@ -351,12 +350,15 @@ class Field:
         return self.strides[self.spatial_dimensions:]
 
     @property
-    def dtype(self) -> PsType:
+    def dtype(self) -> PsType | DynamicType:
         return self._dtype
 
     @property
-    def itemsize(self):
-        return self.dtype.itemsize
+    def itemsize(self) -> int | None:
+        if isinstance(self.dtype, PsType):
+            return self.dtype.itemsize
+        else:
+            return None
 
     def __repr__(self):
         if any(isinstance(s, sp.Symbol) for s in self.spatial_shape):
@@ -1094,17 +1096,20 @@ def _parse_description(description):
         result = type_description_regex.match(d)
         if result:
             data_type_str, size_info = result.group(1), result.group(2).strip().lower()
-            if data_type_str is None:
-                data_type_str = 'float64'
-            data_type_str = data_type_str.lower().strip()
+            if data_type_str is not None:
+                data_type_str = data_type_str.lower().strip()
+
+            if data_type_str:
+                dtype = create_type(data_type_str)
+            else:
+                dtype = DynamicType.NUMERIC_TYPE
 
-            if not data_type_str:
-                data_type_str = 'float64'
             if size_info.endswith('d'):
                 size_info = int(size_info[:-1])
             else:
                 size_info = tuple(int(e) for e in size_info.split(","))
-            return data_type_str, size_info
+            
+            return dtype, size_info
         else:
             raise ValueError("Could not parse field description")
 
diff --git a/tests/frontend/test_field.py b/tests/frontend/test_field.py
index dc804491b..6521e114f 100644
--- a/tests/frontend/test_field.py
+++ b/tests/frontend/test_field.py
@@ -3,7 +3,7 @@ import pytest
 import sympy as sp
 
 import pystencils as ps
-from pystencils import DEFAULTS
+from pystencils import DEFAULTS, DynamicType, create_type, fields
 from pystencils.field import (
     Field,
     FieldType,
@@ -15,6 +15,7 @@ from pystencils.field import (
 def test_field_basic():
     f = Field.create_generic("f", spatial_dimensions=2)
     assert FieldType.is_generic(f)
+    assert f.dtype == DynamicType.NUMERIC_TYPE
     assert f["E"] == f[1, 0]
     assert f["N"] == f[0, 1]
     assert "_" in f.center._latex("dummy")
@@ -41,17 +42,16 @@ def test_field_basic():
     assert f1.ndim == f.ndim
     assert f1.values_per_cell() == f.values_per_cell()
 
-    fixed = ps.fields("f(5, 5) : double[20, 20]")
-    assert fixed.neighbor_vector((1, 1)).shape == (5, 5)
-
     f = Field.create_fixed_size("f", (10, 10), strides=(80, 8), dtype=np.float64)
     assert f.spatial_strides == (10, 1)
     assert f.index_strides == ()
     assert f.center_vector == sp.Matrix([f.center])
+    assert f.dtype == create_type("float64")
 
     f1 = f.new_field_with_different_name("f1")
     assert f1.ndim == f.ndim
     assert f1.values_per_cell() == f.values_per_cell()
+    assert f1.dtype == create_type("float64")
 
     f = Field.create_fixed_size("f", (8, 8, 2, 2), index_dimensions=2)
     assert f.center_vector == sp.Matrix([[f(0, 0), f(0, 1)], [f(1, 0), f(1, 1)]])
@@ -61,16 +61,42 @@ def test_field_basic():
     neighbor = field_access.neighbor(coord_id=0, offset=-2)
     assert neighbor.offsets == (-1, 1)
     assert "_" in neighbor._latex("dummy")
+    assert f.dtype == DynamicType.NUMERIC_TYPE
 
     f = Field.create_fixed_size("f", (8, 8, 2, 2, 2), index_dimensions=3)
     assert f.center_vector == sp.Array(
         [[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)]
     )
+    assert f.dtype == DynamicType.NUMERIC_TYPE
 
     f = Field.create_generic("f", spatial_dimensions=5, index_dimensions=2)
     field_access = f[1, -1, 2, -3, 0](1, 0)
     assert field_access.offsets == (1, -1, 2, -3, 0)
     assert field_access.index == (1, 0)
+    assert f.dtype == DynamicType.NUMERIC_TYPE
+
+
+def test_field_description_parsing():
+    f, g = fields("f(1), g(3): [2D]")
+
+    assert f.dtype == g.dtype == DynamicType.NUMERIC_TYPE
+    assert f.spatial_dimensions == g.spatial_dimensions == 2
+    assert f.index_shape == (1,)
+    assert g.index_shape == (3,)
+
+    h = fields("h: float32[3D]")
+
+    assert h.index_shape == ()
+    assert h.spatial_dimensions == 3
+    assert h.index_dimensions == 0
+    assert h.dtype == create_type("float32")
+
+    f: Field = fields("f(5, 5) : double[20, 20]")
+    
+    assert f.dtype == create_type("float64")
+    assert f.spatial_shape == (20, 20)
+    assert f.index_shape == (5, 5)
+    assert f.neighbor_vector((1, 1)).shape == (5, 5)
 
 
 def test_error_handling():
@@ -145,7 +171,7 @@ def test_error_handling():
 
 
 def test_decorator_scoping():
-    dst = ps.fields("dst : double[2D]")
+    dst = fields("dst : double[2D]")
 
     def f1():
         a = sp.Symbol("a")
@@ -165,7 +191,7 @@ def test_decorator_scoping():
 
 
 def test_string_creation():
-    x, y, z = ps.fields("  x(4),    y(3,5) z : double[  3,  47]")
+    x, y, z = fields("  x(4),    y(3,5) z : double[  3,  47]")
     assert x.index_shape == (4,)
     assert y.index_shape == (3, 5)
     assert z.spatial_shape == (3, 47)
@@ -173,9 +199,9 @@ def test_string_creation():
 
 def test_itemsize():
 
-    x = ps.fields("x: float32[1d]")
-    y = ps.fields("y:  float64[2d]")
-    i = ps.fields("i:  int16[1d]")
+    x = fields("x: float32[1d]")
+    y = fields("y:  float64[2d]")
+    i = fields("i:  int16[1d]")
 
     assert x.itemsize == 4
     assert y.itemsize == 8
@@ -249,7 +275,7 @@ def test_memory_layout_descriptors():
 def test_staggered():
 
     # D2Q5
-    j1, j2, j3 = ps.fields(
+    j1, j2, j3 = fields(
         "j1(2), j2(2,2), j3(2,2,2) : double[2D]", field_type=FieldType.STAGGERED
     )
 
@@ -296,7 +322,7 @@ def test_staggered():
     )
 
     # D2Q9
-    k1, k2 = ps.fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED)
+    k1, k2 = fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED)
 
     assert k1[1, 1](2) == k1.staggered_access("NE")
     assert k1[0, 0](2) == k1.staggered_access("SW")
@@ -319,7 +345,7 @@ def test_staggered():
     ]
 
     # sign reversed when using as flux field
-    r = ps.fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX)
+    r = fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX)
     assert r[0, 0](0) == r.staggered_access("W")
     assert -r[1, 0](0) == r.staggered_access("E")
 
-- 
GitLab