diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index b5c04f1bd029ac8fa62d2efaf3f0983b7cb858b5..b626ffb65522ad25168c2f880f401301d2bc6e03 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -58,7 +58,7 @@ from ..ast.expressions import ( ) from ..constants import PsConstant -from ...types import PsStructType, PsType +from ...types import PsNumericType, PsStructType, PsType from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions @@ -469,7 +469,7 @@ class FreezeExpressions: ] return cast(PsCall, args[0]) - def map_CastFunc(self, cast_expr: CastFunc) -> PsCast: + def map_CastFunc(self, cast_expr: CastFunc) -> PsCast | PsConstantExpr: dtype: PsType match cast_expr.dtype: case DynamicType.NUMERIC_TYPE: @@ -479,7 +479,19 @@ class FreezeExpressions: case other if isinstance(other, PsType): dtype = other - return PsCast(dtype, self.visit_expr(cast_expr.expr)) + arg = self.visit_expr(cast_expr.expr) + if ( + isinstance(arg, PsConstantExpr) + and arg.constant.dtype is None + and isinstance(dtype, PsNumericType) + ): + # As of now, the typifier can not infer the type of a bare constant. + # However, untyped constants may not appear in ASTs from which + # kernel functions are generated. Therefore, we annotate constants + # instead of casting them. + return PsConstantExpr(arg.constant.interpret_as(dtype)) + else: + return PsCast(dtype, arg) def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel: arg1, arg2 = [self.visit_expr(arg) for arg in rel.args] diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index fc085e2be99f61204cde92438811a3b4e41c8bf7..ce2d24f985cc91fd3f2e25abcb67220c991b358e 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -594,7 +594,14 @@ class Typifier: tc.apply_dtype(arr_type, expr) case PsCast(dtype, arg): - self.visit_expr(arg, TypeContext()) + arg_tc = TypeContext() + self.visit_expr(arg, arg_tc) + + if arg_tc.target_type is None: + raise TypificationError( + f"Unable to determine type of argument to Cast: {arg}" + ) + tc.apply_dtype(dtype, expr) case _: diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 270c8f44a6b2a64500446e2284866436069cc704..9467bdd8e092b6062641f3ec284b8d30bdf20714 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -29,6 +29,7 @@ from pystencils.backend.ast.expressions import ( PsGe, PsCall, PsCast, + PsConstantExpr, ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import PsMathFunction, MathFunctions @@ -305,3 +306,6 @@ def test_cast_func(): expr = freeze(CastFunc.as_index(z)) assert expr.structurally_equal(PsCast(ctx.index_dtype, z2)) + + expr = freeze(CastFunc(42, create_type("int16"))) + assert expr.structurally_equal(PsConstantExpr(PsConstant(42, create_type("int16")))) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 2ebfa2ec8ec8cf542aa55bbccd770ef0a7e9f5d4..4c6a4d6024689b7ff96665c9d92f2fcd06986044 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -14,6 +14,9 @@ from pystencils.backend.ast.structural import ( PsBlock, ) from pystencils.backend.ast.expressions import ( + PsAddressOf, + PsArrayInitList, + PsCast, PsConstantExpr, PsSymbolExpr, PsSubscript, @@ -478,3 +481,19 @@ def test_cfunction(): with pytest.raises(TypificationError): _ = typify(PsCall(threeway, (x, p))) + + +def test_inference_fails(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x = PsExpression.make(PsConstant(42)) + + with pytest.raises(TypificationError): + typify(PsEq(x, x)) + + with pytest.raises(TypificationError): + typify(PsArrayInitList([x])) + + with pytest.raises(TypificationError): + typify(PsCast(ctx.default_dtype, x))