Skip to content
Snippets Groups Projects
Commit ef41f461 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

add typed_sympy tests

parent d85f682c
No related branches found
No related tags found
1 merge request!400Extensions and fixes to the type system
Pipeline #67440 passed
......@@ -41,8 +41,8 @@ class DynamicType(Enum):
INDEX_TYPE = auto()
class PsTypeAtom(sp.Atom):
"""Wrapper around a PsType to disguise it as a SymPy atom."""
class TypeAtom(sp.Atom):
"""Wrapper around a type to disguise it as a SymPy atom."""
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
......@@ -74,7 +74,7 @@ class TypedSymbol(sp.Symbol):
assumptions.update(kwargs)
obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
obj._dtype = create_type(dtype)
obj._dtype = dtype
return obj
......@@ -235,11 +235,11 @@ class CastFunc(sp.Function):
if expr.__class__ == CastFunc:
expr = expr.args[0]
if not isinstance(dtype, (PsTypeAtom)):
if not isinstance(dtype, (TypeAtom)):
if isinstance(dtype, DynamicType):
dtype = PsTypeAtom(dtype)
dtype = TypeAtom(dtype)
else:
dtype = PsTypeAtom(create_type(dtype))
dtype = TypeAtom(create_type(dtype))
# to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
# however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
......@@ -269,7 +269,7 @@ class CastFunc(sp.Function):
@property
def dtype(self) -> PsType | DynamicType:
assert isinstance(self.args[1], PsTypeAtom)
assert isinstance(self.args[1], TypeAtom)
return self.args[1].get()
@property
......
......@@ -72,7 +72,7 @@ class PsPointerType(PsDereferencableType):
__match_args__ = ("base_type",)
def __init__(self, base_type: PsType, restrict: bool = True, const: bool = False):
def __init__(self, base_type: PsType, restrict: bool = False, const: bool = False):
super().__init__(base_type, const)
self._restrict = restrict
......@@ -94,7 +94,7 @@ class PsPointerType(PsDereferencableType):
return f"{base_str} *{restrict_str} {self._const_string()}"
def __repr__(self) -> str:
return f"PsPointerType( {repr(self.base_type)}, const={self.const} )"
return f"PsPointerType( {repr(self.base_type)}, const={self.const}, restrict={self.restrict} )"
class PsArrayType(PsDereferencableType):
......
import numpy as np
from pystencils.sympyextensions.typed_sympy import (
TypedSymbol,
CastFunc,
TypeAtom,
DynamicType,
)
from pystencils.types import create_type
from pystencils.types.quick import UInt, Ptr
def test_type_atoms():
atom1 = TypeAtom(create_type("int32"))
atom2 = TypeAtom(create_type("int32"))
assert atom1 == atom2
atom3 = TypeAtom(create_type("const int32"))
assert atom1 != atom3
atom4 = TypeAtom(DynamicType.INDEX_TYPE)
atom5 = TypeAtom(DynamicType.NUMERIC_TYPE)
assert atom3 != atom4
assert atom4 != atom5
def test_typed_symbol():
x = TypedSymbol("x", "uint32")
x2 = TypedSymbol("x", "uint64 *")
z = TypedSymbol("z", "float32")
assert x == TypedSymbol("x", np.uint32)
assert x != x2
assert x.dtype == UInt(32)
assert x2.dtype == Ptr(UInt(64))
assert x.is_integer
assert x.is_nonnegative
assert not x2.is_integer
assert z.is_real
assert not z.is_nonnegative
def test_cast_func():
assert (
CastFunc(TypedSymbol("s", np.uint), np.int64).canonical
== TypedSymbol("s", np.uint).canonical
)
a = CastFunc(5, np.uint)
assert a.is_negative is False
assert a.is_nonnegative
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment