From ef41f461073e2a0910e7e8d236ec38f41d09ecb3 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 8 Jul 2024 21:48:54 +0200 Subject: [PATCH] add typed_sympy tests --- src/pystencils/sympyextensions/typed_sympy.py | 14 ++--- src/pystencils/types/types.py | 4 +- tests/symbolics/test_typed_sympy.py | 57 +++++++++++++++++++ 3 files changed, 66 insertions(+), 9 deletions(-) create mode 100644 tests/symbolics/test_typed_sympy.py diff --git a/src/pystencils/sympyextensions/typed_sympy.py b/src/pystencils/sympyextensions/typed_sympy.py index cd5c80c88..611e5e7ac 100644 --- a/src/pystencils/sympyextensions/typed_sympy.py +++ b/src/pystencils/sympyextensions/typed_sympy.py @@ -41,8 +41,8 @@ class DynamicType(Enum): INDEX_TYPE = auto() -class PsTypeAtom(sp.Atom): - """Wrapper around a PsType to disguise it as a SymPy atom.""" +class TypeAtom(sp.Atom): + """Wrapper around a type to disguise it as a SymPy atom.""" def __new__(cls, *args, **kwargs): return sp.Basic.__new__(cls) @@ -74,7 +74,7 @@ class TypedSymbol(sp.Symbol): assumptions.update(kwargs) obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions) - obj._dtype = create_type(dtype) + obj._dtype = dtype return obj @@ -235,11 +235,11 @@ class CastFunc(sp.Function): if expr.__class__ == CastFunc: expr = expr.args[0] - if not isinstance(dtype, (PsTypeAtom)): + if not isinstance(dtype, (TypeAtom)): if isinstance(dtype, DynamicType): - dtype = PsTypeAtom(dtype) + dtype = TypeAtom(dtype) else: - dtype = PsTypeAtom(create_type(dtype)) + dtype = TypeAtom(create_type(dtype)) # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads @@ -269,7 +269,7 @@ class CastFunc(sp.Function): @property def dtype(self) -> PsType | DynamicType: - assert isinstance(self.args[1], PsTypeAtom) + assert isinstance(self.args[1], TypeAtom) return self.args[1].get() @property diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index 658225762..61e3d73fd 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -72,7 +72,7 @@ class PsPointerType(PsDereferencableType): __match_args__ = ("base_type",) - def __init__(self, base_type: PsType, restrict: bool = True, const: bool = False): + def __init__(self, base_type: PsType, restrict: bool = False, const: bool = False): super().__init__(base_type, const) self._restrict = restrict @@ -94,7 +94,7 @@ class PsPointerType(PsDereferencableType): return f"{base_str} *{restrict_str} {self._const_string()}" def __repr__(self) -> str: - return f"PsPointerType( {repr(self.base_type)}, const={self.const} )" + return f"PsPointerType( {repr(self.base_type)}, const={self.const}, restrict={self.restrict} )" class PsArrayType(PsDereferencableType): diff --git a/tests/symbolics/test_typed_sympy.py b/tests/symbolics/test_typed_sympy.py new file mode 100644 index 000000000..41015f96b --- /dev/null +++ b/tests/symbolics/test_typed_sympy.py @@ -0,0 +1,57 @@ +import numpy as np + +from pystencils.sympyextensions.typed_sympy import ( + TypedSymbol, + CastFunc, + TypeAtom, + DynamicType, +) +from pystencils.types import create_type +from pystencils.types.quick import UInt, Ptr + + +def test_type_atoms(): + atom1 = TypeAtom(create_type("int32")) + atom2 = TypeAtom(create_type("int32")) + + assert atom1 == atom2 + + atom3 = TypeAtom(create_type("const int32")) + assert atom1 != atom3 + + atom4 = TypeAtom(DynamicType.INDEX_TYPE) + atom5 = TypeAtom(DynamicType.NUMERIC_TYPE) + + assert atom3 != atom4 + assert atom4 != atom5 + + +def test_typed_symbol(): + x = TypedSymbol("x", "uint32") + x2 = TypedSymbol("x", "uint64 *") + z = TypedSymbol("z", "float32") + + assert x == TypedSymbol("x", np.uint32) + assert x != x2 + + assert x.dtype == UInt(32) + assert x2.dtype == Ptr(UInt(64)) + + assert x.is_integer + assert x.is_nonnegative + + assert not x2.is_integer + + assert z.is_real + assert not z.is_nonnegative + + +def test_cast_func(): + assert ( + CastFunc(TypedSymbol("s", np.uint), np.int64).canonical + == TypedSymbol("s", np.uint).canonical + ) + + a = CastFunc(5, np.uint) + assert a.is_negative is False + assert a.is_nonnegative -- GitLab