Skip to content
Snippets Groups Projects
Commit e9bf1249 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'v2.0-dev' into fhennig/nested-type-context

parents bd6b63ab 05aa74d2
No related branches found
No related tags found
1 merge request!418Nesting of Type Contexts, Type Hints, and Improved Array Typing
Pipeline #68015 failed
......@@ -58,7 +58,7 @@ from ..ast.expressions import (
)
from ..constants import PsConstant
from ...types import PsStructType, PsType
from ...types import PsNumericType, PsStructType, PsType
from ..exceptions import PsInputError
from ..functions import PsMathFunction, MathFunctions
......@@ -469,7 +469,7 @@ class FreezeExpressions:
]
return cast(PsCall, args[0])
def map_CastFunc(self, cast_expr: CastFunc) -> PsCast:
def map_CastFunc(self, cast_expr: CastFunc) -> PsCast | PsConstantExpr:
dtype: PsType
match cast_expr.dtype:
case DynamicType.NUMERIC_TYPE:
......@@ -479,7 +479,19 @@ class FreezeExpressions:
case other if isinstance(other, PsType):
dtype = other
return PsCast(dtype, self.visit_expr(cast_expr.expr))
arg = self.visit_expr(cast_expr.expr)
if (
isinstance(arg, PsConstantExpr)
and arg.constant.dtype is None
and isinstance(dtype, PsNumericType)
):
# As of now, the typifier can not infer the type of a bare constant.
# However, untyped constants may not appear in ASTs from which
# kernel functions are generated. Therefore, we annotate constants
# instead of casting them.
return PsConstantExpr(arg.constant.interpret_as(dtype))
else:
return PsCast(dtype, arg)
def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel:
arg1, arg2 = [self.visit_expr(arg) for arg in rel.args]
......
......@@ -734,7 +734,14 @@ class Typifier:
tc.apply_dtype(arr_type, expr)
case PsCast(dtype, arg):
self.visit_expr(arg, TypeContext())
arg_tc = TypeContext()
self.visit_expr(arg, arg_tc)
if arg_tc.target_type is None:
raise TypificationError(
f"Unable to determine type of argument to Cast: {arg}"
)
tc.apply_dtype(dtype, expr)
case _:
......
......@@ -29,6 +29,7 @@ from pystencils.backend.ast.expressions import (
PsGe,
PsCall,
PsCast,
PsConstantExpr,
)
from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import PsMathFunction, MathFunctions
......@@ -305,3 +306,6 @@ def test_cast_func():
expr = freeze(CastFunc.as_index(z))
assert expr.structurally_equal(PsCast(ctx.index_dtype, z2))
expr = freeze(CastFunc(42, create_type("int16")))
assert expr.structurally_equal(PsConstantExpr(PsConstant(42, create_type("int16"))))
......@@ -14,6 +14,9 @@ from pystencils.backend.ast.structural import (
PsBlock,
)
from pystencils.backend.ast.expressions import (
PsAddressOf,
PsArrayInitList,
PsCast,
PsConstantExpr,
PsSymbolExpr,
PsSubscript,
......@@ -568,3 +571,19 @@ def test_cfunction():
with pytest.raises(TypificationError):
_ = typify(PsCall(threeway, (x, p)))
def test_inference_fails():
ctx = KernelCreationContext()
typify = Typifier(ctx)
x = PsExpression.make(PsConstant(42))
with pytest.raises(TypificationError):
typify(PsEq(x, x))
with pytest.raises(TypificationError):
typify(PsArrayInitList([x]))
with pytest.raises(TypificationError):
typify(PsCast(ctx.default_dtype, x))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment