From ef41f461073e2a0910e7e8d236ec38f41d09ecb3 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 8 Jul 2024 21:48:54 +0200
Subject: [PATCH] add typed_sympy tests

---
 src/pystencils/sympyextensions/typed_sympy.py | 14 ++---
 src/pystencils/types/types.py                 |  4 +-
 tests/symbolics/test_typed_sympy.py           | 57 +++++++++++++++++++
 3 files changed, 66 insertions(+), 9 deletions(-)
 create mode 100644 tests/symbolics/test_typed_sympy.py

diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py
index cd5c80c88..611e5e7ac 100644
--- a/src/pystencils/sympyextensions/typed_sympy.py
+++ b/src/pystencils/sympyextensions/typed_sympy.py
@@ -41,8 +41,8 @@ class DynamicType(Enum):
     INDEX_TYPE = auto()
 
 
-class PsTypeAtom(sp.Atom):
-    """Wrapper around a PsType to disguise it as a SymPy atom."""
+class TypeAtom(sp.Atom):
+    """Wrapper around a type to disguise it as a SymPy atom."""
 
     def __new__(cls, *args, **kwargs):
         return sp.Basic.__new__(cls)
@@ -74,7 +74,7 @@ class TypedSymbol(sp.Symbol):
         assumptions.update(kwargs)
 
         obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
-        obj._dtype = create_type(dtype)
+        obj._dtype = dtype
 
         return obj
 
@@ -235,11 +235,11 @@ class CastFunc(sp.Function):
         if expr.__class__ == CastFunc:
             expr = expr.args[0]
 
-        if not isinstance(dtype, (PsTypeAtom)):
+        if not isinstance(dtype, (TypeAtom)):
             if isinstance(dtype, DynamicType):
-                dtype = PsTypeAtom(dtype)
+                dtype = TypeAtom(dtype)
             else:
-                dtype = PsTypeAtom(create_type(dtype))
+                dtype = TypeAtom(create_type(dtype))
                 
         # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
         # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
@@ -269,7 +269,7 @@ class CastFunc(sp.Function):
 
     @property
     def dtype(self) -> PsType | DynamicType:
-        assert isinstance(self.args[1], PsTypeAtom)
+        assert isinstance(self.args[1], TypeAtom)
         return self.args[1].get()
 
     @property
diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py
index 658225762..61e3d73fd 100644
--- a/src/pystencils/types/types.py
+++ b/src/pystencils/types/types.py
@@ -72,7 +72,7 @@ class PsPointerType(PsDereferencableType):
 
     __match_args__ = ("base_type",)
 
-    def __init__(self, base_type: PsType, restrict: bool = True, const: bool = False):
+    def __init__(self, base_type: PsType, restrict: bool = False, const: bool = False):
         super().__init__(base_type, const)
         self._restrict = restrict
 
@@ -94,7 +94,7 @@ class PsPointerType(PsDereferencableType):
         return f"{base_str} *{restrict_str} {self._const_string()}"
 
     def __repr__(self) -> str:
-        return f"PsPointerType( {repr(self.base_type)}, const={self.const} )"
+        return f"PsPointerType( {repr(self.base_type)}, const={self.const}, restrict={self.restrict} )"
 
 
 class PsArrayType(PsDereferencableType):
diff --git a/tests/symbolics/test_typed_sympy.py b/tests/symbolics/test_typed_sympy.py
new file mode 100644
index 000000000..41015f96b
--- /dev/null
+++ b/tests/symbolics/test_typed_sympy.py
@@ -0,0 +1,57 @@
+import numpy as np
+
+from pystencils.sympyextensions.typed_sympy import (
+    TypedSymbol,
+    CastFunc,
+    TypeAtom,
+    DynamicType,
+)
+from pystencils.types import create_type
+from pystencils.types.quick import UInt, Ptr
+
+
+def test_type_atoms():
+    atom1 = TypeAtom(create_type("int32"))
+    atom2 = TypeAtom(create_type("int32"))
+
+    assert atom1 == atom2
+
+    atom3 = TypeAtom(create_type("const int32"))
+    assert atom1 != atom3
+
+    atom4 = TypeAtom(DynamicType.INDEX_TYPE)
+    atom5 = TypeAtom(DynamicType.NUMERIC_TYPE)
+
+    assert atom3 != atom4
+    assert atom4 != atom5
+
+
+def test_typed_symbol():
+    x = TypedSymbol("x", "uint32")
+    x2 = TypedSymbol("x", "uint64 *")
+    z = TypedSymbol("z", "float32")
+
+    assert x == TypedSymbol("x", np.uint32)
+    assert x != x2
+
+    assert x.dtype == UInt(32)
+    assert x2.dtype == Ptr(UInt(64))
+
+    assert x.is_integer
+    assert x.is_nonnegative
+
+    assert not x2.is_integer
+
+    assert z.is_real
+    assert not z.is_nonnegative
+
+
+def test_cast_func():
+    assert (
+        CastFunc(TypedSymbol("s", np.uint), np.int64).canonical
+        == TypedSymbol("s", np.uint).canonical
+    )
+
+    a = CastFunc(5, np.uint)
+    assert a.is_negative is False
+    assert a.is_nonnegative
-- 
GitLab