Skip to content
Snippets Groups Projects
Commit 095436ed authored by Frederik Hennig's avatar Frederik Hennig
Browse files

additional tests & some fixes

parent 80f82607
No related branches found
No related tags found
No related merge requests found
Pipeline #61471 failed
......@@ -32,15 +32,15 @@ class FreezeExpressions(SympyToPymbolicMapper):
@overload
def __call__(self, asms: AssignmentCollection) -> PsBlock:
...
pass
@overload
def __call__(self, expr: sp.Expr) -> PsExpression:
...
pass
@overload
def __call__(self, expr: Assignment) -> PsAssignment:
...
def __call__(self, asm: Assignment) -> PsAssignment:
pass
def __call__(self, obj):
if isinstance(obj, AssignmentCollection):
......
......@@ -171,8 +171,9 @@ class Typifier(Mapper):
def typify_expression(
self, expr: Any, target_type: PsNumericType | None = None
) -> ExprOrConstant:
return self.rec(expr, TypeContext(target_type))
) -> tuple[ExprOrConstant, PsNumericType]:
tc = TypeContext(target_type)
return self.rec(expr, tc)
# Leaf nodes: Variables, Typed Variables, Constants and TypedConstants
......
......@@ -13,7 +13,7 @@ from .basic_types import (
deconstify,
)
from .quick import make_type
from .quick import make_type, make_numeric_type
from .exception import PsTypeError
......@@ -31,5 +31,6 @@ __all__ = [
"constify",
"deconstify",
"make_type",
"make_numeric_type",
"PsTypeError",
]
......@@ -208,10 +208,9 @@ class PsStructType(PsAbstractType):
def _c_string(self) -> str:
if self._name is None:
# raise PsInternalCompilerError(
# "Cannot retrieve C string for anonymous struct type"
# )
return "<anonymous>"
raise PsInternalCompilerError(
"Cannot retrieve C string for anonymous struct type"
)
return self._name
def __eq__(self, other: object) -> bool:
......@@ -502,6 +501,8 @@ class PsIeeeFloatType(PsScalarType):
def _c_string(self) -> str:
match self._width:
case 16:
return f"{self._const_string()}half"
case 32:
return f"{self._const_string()}float"
case 64:
......
......@@ -34,6 +34,8 @@ def interpret_python_type(t: type) -> PsAbstractType:
if t is np.int64:
return PsSignedIntegerType(64)
if t is np.float16:
return PsIeeeFloatType(16)
if t is np.float32:
return PsIeeeFloatType(32)
if t is np.float64:
......
......@@ -11,6 +11,7 @@ import numpy as np
from .basic_types import (
PsAbstractType,
PsCustomType,
PsNumericType,
PsScalarType,
PsPointerType,
PsIntegerType,
......@@ -39,11 +40,7 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType:
- Instances of `PsAbstractType` will be returned as they are
"""
from .parsing import (
parse_type_string,
interpret_python_type,
interpret_numpy_dtype
)
from .parsing import parse_type_string, interpret_python_type, interpret_numpy_dtype
if isinstance(type_spec, PsAbstractType):
return type_spec
......@@ -56,6 +53,16 @@ def make_type(type_spec: UserTypeSpec) -> PsAbstractType:
raise ValueError(f"{type_spec} is not a valid type specification.")
def make_numeric_type(type_spec: UserTypeSpec) -> PsNumericType:
"""Like `make_type`, but only for numeric types."""
dtype = make_type(type_spec)
if not isinstance(dtype, PsNumericType):
raise ValueError(
f"Given type {type_spec} does not translate to a numeric type."
)
return dtype
Custom = PsCustomType
"""`Custom(name)` matches `PsCustomType(name)`"""
......
......@@ -15,8 +15,8 @@ from pystencils.cpu.cpujit import compile_and_load
def test_pairwise_addition():
idx_type = SInt(64)
u = PsLinearizedArray("u", Fp(64, const=True), (..., ...), (..., ...), index_dtype=idx_type)
v = PsLinearizedArray("v", Fp(64), (..., ...), (..., ...), index_dtype=idx_type)
u = PsLinearizedArray("u", Fp(64, const=True), (...,), (...,), index_dtype=idx_type)
v = PsLinearizedArray("v", Fp(64), (...,), (...,), index_dtype=idx_type)
u_data = PsArrayBasePointer("u_data", u)
v_data = PsArrayBasePointer("v_data", v)
......
import pytest
import numpy as np
from pystencils.nbackend.exceptions import PsInternalCompilerError
from pystencils.nbackend.types import *
from pystencils.nbackend.types.quick import *
@pytest.mark.parametrize(
"numpy_type",
[
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.int8,
np.int16,
np.int32,
np.int64,
np.float16,
np.float32,
np.float64,
],
)
def test_numpy_translation(numpy_type):
dtype_obj = np.dtype(numpy_type)
ps_type = make_type(numpy_type)
assert isinstance(ps_type, PsNumericType)
assert ps_type.numpy_dtype == dtype_obj
assert ps_type.itemsize == dtype_obj.itemsize
assert isinstance(ps_type.create_constant(13), numpy_type)
if ps_type.is_int():
with pytest.raises(PsTypeError):
ps_type.create_constant(13.0)
with pytest.raises(PsTypeError):
ps_type.create_constant(1.75)
if ps_type.is_sint():
assert numpy_type(17) == ps_type.create_constant(17)
assert numpy_type(-4) == ps_type.create_constant(-4)
if ps_type.is_uint():
with pytest.raises(PsTypeError):
ps_type.create_constant(-4)
if ps_type.is_float():
assert numpy_type(17.3) == ps_type.create_constant(17.3)
assert numpy_type(-4.2) == ps_type.create_constant(-4.2)
def test_constify():
t = PsCustomType("std::shared_ptr< Custom >")
assert deconstify(t) == t
assert deconstify(constify(t)) == t
s = PsCustomType("Field", const=True)
assert constify(s) == s
def test_struct_types():
t = PsStructType(
[
PsStructType.Member("data", Ptr(Fp(32))),
("size", UInt(32)),
]
)
assert t.anonymous
with pytest.raises(PsInternalCompilerError):
str(t)
import pytest
import sympy as sp
import numpy as np
import pymbolic.primitives as pb
from pystencils import Assignment
from pystencils import Assignment, TypedSymbol
from pystencils.nbackend.ast import PsDeclaration
from pystencils.nbackend.types import constify
from pystencils.nbackend.types import constify, make_numeric_type
from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable
from pystencils.nbackend.kernelcreation.options import KernelCreationOptions
from pystencils.nbackend.kernelcreation.context import KernelCreationContext
from pystencils.nbackend.kernelcreation.freeze import FreezeExpressions
from pystencils.nbackend.kernelcreation.typification import Typifier
from pystencils.nbackend.kernelcreation.typification import Typifier, TypificationError
def test_typify_simple():
......@@ -68,3 +69,29 @@ def test_contextual_typing():
pytest.fail(f"Unexpected expression: {expr}")
check(expr.expression)
def test_erronous_typing():
options = KernelCreationOptions(default_dtype=make_numeric_type(np.float64))
ctx = KernelCreationContext(options)
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y, z = sp.symbols("x, y, z")
q = TypedSymbol("q", np.float32)
w = TypedSymbol("w", np.float16)
expr = freeze(2 * x + 3 * y + q - 4)
with pytest.raises(TypificationError):
typify(expr)
asm = Assignment(q, 3 - w)
fasm = freeze(asm)
with pytest.raises(TypificationError):
typify(fasm)
asm = Assignment(q, 3 - x)
fasm = freeze(asm)
with pytest.raises(TypificationError):
typify(fasm)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment