From d85f682c2b37be3e7a3ca101f50c894c5247eab5 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 8 Jul 2024 21:24:50 +0200
Subject: [PATCH] some extensions to the type system

---
 src/pystencils/__init__.py                    |  3 +-
 .../backend/kernelcreation/freeze.py          | 18 ++++--
 src/pystencils/sympyextensions/typed_sympy.py | 60 ++++++++++++++-----
 src/pystencils/types/parsing.py               |  2 +
 src/pystencils/types/types.py                 |  2 +-
 tests/nbackend/kernelcreation/test_freeze.py  | 53 ++++++++++++----
 tests/nbackend/types/test_types.py            | 11 ++++
 7 files changed, 115 insertions(+), 34 deletions(-)

diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py
index 3d3b7846a..c39cd3b82 100644
--- a/src/pystencils/__init__.py
+++ b/src/pystencils/__init__.py
@@ -6,7 +6,7 @@ from . import fd
 from . import stencil as stencil
 from .display_utils import get_code_obj, get_code_str, show_code, to_dot
 from .field import Field, FieldType, fields
-from .types import create_type
+from .types import create_type, create_numeric_type
 from .cache import clear_cache
 from .config import (
     CreateKernelConfig,
@@ -41,6 +41,7 @@ __all__ = [
     "DEFAULTS",
     "TypedSymbol",
     "create_type",
+    "create_numeric_type",
     "make_slice",
     "CreateKernelConfig",
     "CpuOptimConfig",
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 3865db38f..59fa04b3b 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -7,13 +7,12 @@ import sympy.core.relational
 import sympy.logic.boolalg
 from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
 
+from ...sympyextensions.astnodes import Assignment, AssignmentCollection
 from ...sympyextensions import (
-    Assignment,
-    AssignmentCollection,
     integer_functions,
     ConditionalFieldAccess,
 )
-from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc
+from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType
 from ...sympyextensions.pointers import AddressOf
 from ...field import Field, FieldType
 
@@ -58,7 +57,7 @@ from ..ast.expressions import (
 )
 
 from ..constants import PsConstant
-from ...types import PsStructType
+from ...types import PsStructType, PsType
 from ..exceptions import PsInputError
 from ..functions import PsMathFunction, MathFunctions
 
@@ -465,7 +464,16 @@ class FreezeExpressions:
         return cast(PsCall, args[0])
 
     def map_CastFunc(self, cast_expr: CastFunc) -> PsCast:
-        return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr))
+        dtype: PsType
+        match cast_expr.dtype:
+            case DynamicType.NUMERIC_TYPE:
+                dtype = self._ctx.default_dtype
+            case DynamicType.INDEX_TYPE:
+                dtype = self._ctx.index_dtype
+            case other if isinstance(other, PsType):
+                dtype = other
+
+        return PsCast(dtype, self.visit_expr(cast_expr.expr))
 
     def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel:
         arg1, arg2 = [self.visit_expr(arg) for arg in rel.args]
diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py
index e022db511..cd5c80c88 100644
--- a/src/pystencils/sympyextensions/typed_sympy.py
+++ b/src/pystencils/sympyextensions/typed_sympy.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
 import sympy as sp
+from enum import Enum, auto
 
-from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, create_type
+from ..types import PsType, PsNumericType, PsPointerType, PsBoolType, PsIntegerType, create_type
 
 
 def assumptions_from_dtype(dtype: PsType):
@@ -33,20 +36,28 @@ def is_loop_counter_symbol(symbol):
         return None
 
 
+class DynamicType(Enum):
+    NUMERIC_TYPE = auto()
+    INDEX_TYPE = auto()
+
+
 class PsTypeAtom(sp.Atom):
     """Wrapper around a PsType to disguise it as a SymPy atom."""
 
     def __new__(cls, *args, **kwargs):
         return sp.Basic.__new__(cls)
     
-    def __init__(self, dtype: PsType) -> None:
+    def __init__(self, dtype: PsType | DynamicType) -> None:
         self._dtype = dtype
 
     def _sympystr(self, *args, **kwargs):
         return str(self._dtype)
 
-    def get(self) -> PsType:
+    def get(self) -> PsType | DynamicType:
         return self._dtype
+    
+    def _hashable_content(self):
+        return (self._dtype, )
 
 
 class TypedSymbol(sp.Symbol):
@@ -105,12 +116,15 @@ class FieldStrideSymbol(TypedSymbol):
         obj = FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds)
         return obj
 
-    def __new_stage2__(cls, field_name: str, coordinate: int):
+    def __new_stage2__(cls, field_name: str, coordinate: int, dtype: PsIntegerType | None = None):
         from ..defaults import DEFAULTS
+        
+        if dtype is None:
+            dtype = DEFAULTS.index_dtype
 
         name = f"_stride_{field_name}_{coordinate}"
         obj = super(FieldStrideSymbol, cls).__xnew__(
-            cls, name, DEFAULTS.index_dtype, positive=True
+            cls, name, dtype, positive=True
         )
         obj.field_name = field_name
         obj.coordinate = coordinate
@@ -138,12 +152,15 @@ class FieldShapeSymbol(TypedSymbol):
         obj = FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds)
         return obj
 
-    def __new_stage2__(cls, field_name: str, coordinate: int):
+    def __new_stage2__(cls, field_name: str, coordinate: int, dtype: PsIntegerType | None = None):
         from ..defaults import DEFAULTS
+        
+        if dtype is None:
+            dtype = DEFAULTS.index_dtype
 
         name = f"_size_{field_name}_{coordinate}"
         obj = super(FieldShapeSymbol, cls).__xnew__(
-            cls, name, DEFAULTS.index_dtype, positive=True
+            cls, name, dtype, positive=True
         )
         obj.field_name = field_name
         obj.coordinate = coordinate
@@ -190,10 +207,21 @@ class FieldPointerSymbol(TypedSymbol):
 
 
 class CastFunc(sp.Function):
+    """Use this function to introduce a static type cast into the output code.
+
+    Usage: ``CastFunc(expr, target_type)`` becomes, in C code, ``(target_type) expr``.
+    The `target_type` may be a valid pystencils type specification parsable by `create_type`,
+    or a special value of the `DynamicType` enum.
+    These dynamic types can be used to select the target type according to the code generation context.
     """
-    CastFunc is used in order to introduce static casts. They are especially useful as a way to signal what type
-    a certain node should have, if it is impossible to add a type to a node, e.g. a sp.Number.
-    """
+
+    @staticmethod
+    def as_numeric(expr):
+        return CastFunc(expr, DynamicType.NUMERIC_TYPE)
+    
+    @staticmethod
+    def as_index(expr):
+        return CastFunc(expr, DynamicType.INDEX_TYPE)
 
     is_Atom = True
 
@@ -207,8 +235,12 @@ class CastFunc(sp.Function):
         if expr.__class__ == CastFunc:
             expr = expr.args[0]
 
-        if not isinstance(dtype, PsTypeAtom):
-            dtype = PsTypeAtom(create_type(dtype))
+        if not isinstance(dtype, (PsTypeAtom)):
+            if isinstance(dtype, DynamicType):
+                dtype = PsTypeAtom(dtype)
+            else:
+                dtype = PsTypeAtom(create_type(dtype))
+                
         # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
         # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
         # to problems when for example comparing cast_func's for equality
@@ -236,7 +268,7 @@ class CastFunc(sp.Function):
         return self.args[0].is_commutative
 
     @property
-    def dtype(self) -> PsType:
+    def dtype(self) -> PsType | DynamicType:
         assert isinstance(self.args[1], PsTypeAtom)
         return self.args[1].get()
 
@@ -246,7 +278,7 @@ class CastFunc(sp.Function):
 
     @property
     def is_integer(self):
-        if isinstance(self.dtype, PsNumericType):
+        if isinstance(self.dtype, PsNumericType) or self.dtype == DynamicType.INDEX_TYPE:
             return self.dtype.is_int() or super().is_integer
         else:
             return super().is_integer
diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py
index 75fb35d22..d6522e5bb 100644
--- a/src/pystencils/types/parsing.py
+++ b/src/pystencils/types/parsing.py
@@ -158,6 +158,8 @@ def parse_type_name(typename: str, const: bool):
         case "uint8" | "uint8_t":
             return PsUnsignedIntegerType(8, const=const)
 
+        case "half" | "float16":
+            return PsIeeeFloatType(16, const=const)
         case "float" | "float32":
             return PsIeeeFloatType(32, const=const)
         case "double" | "float64":
diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py
index 2f0f2ff46..658225762 100644
--- a/src/pystencils/types/types.py
+++ b/src/pystencils/types/types.py
@@ -200,7 +200,7 @@ class PsStructType(PsType):
     @property
     def numpy_dtype(self) -> np.dtype:
         members = [(m.name, m.dtype.numpy_dtype) for m in self._members]
-        return np.dtype(members)
+        return np.dtype(members, align=True)
 
     @property
     def itemsize(self) -> int:
diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py
index b22df7d0b..f16a468e7 100644
--- a/tests/nbackend/kernelcreation/test_freeze.py
+++ b/tests/nbackend/kernelcreation/test_freeze.py
@@ -1,7 +1,8 @@
 import sympy as sp
 import pytest
 
-from pystencils import Assignment, fields
+from pystencils import Assignment, fields, create_type, create_numeric_type
+from pystencils.sympyextensions import CastFunc
 
 from pystencils.backend.ast.structural import (
     PsAssignment,
@@ -26,7 +27,8 @@ from pystencils.backend.ast.expressions import (
     PsLe,
     PsGt,
     PsGe,
-    PsCall
+    PsCall,
+    PsCast,
 )
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.functions import PsMathFunction, MathFunctions
@@ -182,14 +184,17 @@ def test_freeze_booleans():
     assert expr.structurally_equal(PsOr(PsOr(PsOr(w2, x2), y2), z2))
 
 
-@pytest.mark.parametrize("rel_pair", [
-    (sp.Eq, PsEq),
-    (sp.Ne, PsNe),
-    (sp.Lt, PsLt),
-    (sp.Gt, PsGt),
-    (sp.Le, PsLe),
-    (sp.Ge, PsGe)
-])
+@pytest.mark.parametrize(
+    "rel_pair",
+    [
+        (sp.Eq, PsEq),
+        (sp.Ne, PsNe),
+        (sp.Lt, PsLt),
+        (sp.Gt, PsGt),
+        (sp.Le, PsLe),
+        (sp.Ge, PsGe),
+    ],
+)
 def test_freeze_relations(rel_pair):
     ctx = KernelCreationContext()
     freeze = FreezeExpressions(ctx)
@@ -211,7 +216,7 @@ def test_freeze_piecewise():
     freeze = FreezeExpressions(ctx)
 
     p, q, x, y, z = sp.symbols("p, q, x, y, z")
-    
+
     p2 = PsExpression.make(ctx.get_symbol("p"))
     q2 = PsExpression.make(ctx.get_symbol("q"))
     x2 = PsExpression.make(ctx.get_symbol("x"))
@@ -222,10 +227,10 @@ def test_freeze_piecewise():
     expr = freeze(piecewise)
 
     assert isinstance(expr, PsTernary)
-    
+
     should = PsTernary(p2, x2, PsTernary(q2, y2, z2))
     assert expr.structurally_equal(should)
-    
+
     piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q)))
     with pytest.raises(FreezeError):
         freeze(piecewise)
@@ -259,3 +264,25 @@ def test_multiarg_min_max():
 
     expr = freeze(sp.Max(w, x, y, z))
     assert expr.structurally_equal(op(op(w2, x2), op(y2, z2)))
+
+
+def test_cast_func():
+    ctx = KernelCreationContext(
+        default_dtype=create_numeric_type("float16"), index_dtype=create_type("int16")
+    )
+    freeze = FreezeExpressions(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+
+    x2 = PsExpression.make(ctx.get_symbol("x"))
+    y2 = PsExpression.make(ctx.get_symbol("y"))
+    z2 = PsExpression.make(ctx.get_symbol("z"))
+
+    expr = freeze(CastFunc(x, create_type("int")))
+    assert expr.structurally_equal(PsCast(create_type("int"), x2))
+
+    expr = freeze(CastFunc.as_numeric(y))
+    assert expr.structurally_equal(PsCast(ctx.default_dtype, y2))
+
+    expr = freeze(CastFunc.as_index(z))
+    assert expr.structurally_equal(PsCast(ctx.index_dtype, z2))
diff --git a/tests/nbackend/types/test_types.py b/tests/nbackend/types/test_types.py
index 39f89e6fe..1cc2ae0e4 100644
--- a/tests/nbackend/types/test_types.py
+++ b/tests/nbackend/types/test_types.py
@@ -139,6 +139,17 @@ def test_struct_types():
     with pytest.raises(PsTypeError):
         t.c_string()
 
+    t = PsStructType([
+        ("a", SInt(8)),
+        ("b", SInt(16)),
+        ("c", SInt(64))
+    ])
+
+    #   Check that natural alignment is taken into account
+    numpy_type = np.dtype([("a", "i1"), ("b", "i2"), ("c", "i8")], align=True)
+    assert t.numpy_dtype == numpy_type
+    assert t.itemsize == numpy_type.itemsize == 16
+
 
 def test_pickle():
     types = [
-- 
GitLab