Skip to content
Snippets Groups Projects
Commit d1a9c0ef authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon:
Browse files

freeze casts of bare constants to typed PsConstantExprs

parent 0c86a9b6
No related branches found
No related tags found
1 merge request!411Freeze casts of bare constants to typed PsConstantExprs
...@@ -58,7 +58,7 @@ from ..ast.expressions import ( ...@@ -58,7 +58,7 @@ from ..ast.expressions import (
) )
from ..constants import PsConstant from ..constants import PsConstant
from ...types import PsStructType, PsType from ...types import PsNumericType, PsStructType, PsType
from ..exceptions import PsInputError from ..exceptions import PsInputError
from ..functions import PsMathFunction, MathFunctions from ..functions import PsMathFunction, MathFunctions
...@@ -462,7 +462,7 @@ class FreezeExpressions: ...@@ -462,7 +462,7 @@ class FreezeExpressions:
] ]
return cast(PsCall, args[0]) return cast(PsCall, args[0])
def map_CastFunc(self, cast_expr: CastFunc) -> PsCast: def map_CastFunc(self, cast_expr: CastFunc) -> PsCast | PsConstantExpr:
dtype: PsType dtype: PsType
match cast_expr.dtype: match cast_expr.dtype:
case DynamicType.NUMERIC_TYPE: case DynamicType.NUMERIC_TYPE:
...@@ -472,7 +472,19 @@ class FreezeExpressions: ...@@ -472,7 +472,19 @@ class FreezeExpressions:
case other if isinstance(other, PsType): case other if isinstance(other, PsType):
dtype = other dtype = other
return PsCast(dtype, self.visit_expr(cast_expr.expr)) arg = self.visit_expr(cast_expr.expr)
if (
isinstance(arg, PsConstantExpr)
and arg.constant.dtype is None
and isinstance(dtype, PsNumericType)
):
# As of now, the typifier can not infer the type of a bare constant.
# However, untyped constants may not appear in ASTs from which
# kernel functions are generated. Therefore, we annotate constants
# instead of casting them.
return PsConstantExpr(arg.constant.interpret_as(dtype))
else:
return PsCast(dtype, arg)
def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel: def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel:
arg1, arg2 = [self.visit_expr(arg) for arg in rel.args] arg1, arg2 = [self.visit_expr(arg) for arg in rel.args]
......
...@@ -594,7 +594,14 @@ class Typifier: ...@@ -594,7 +594,14 @@ class Typifier:
tc.apply_dtype(arr_type, expr) tc.apply_dtype(arr_type, expr)
case PsCast(dtype, arg): case PsCast(dtype, arg):
self.visit_expr(arg, TypeContext()) arg_tc = TypeContext()
self.visit_expr(arg, arg_tc)
if arg_tc.target_type is None:
raise TypificationError(
f"Unable to determine type of argument to Cast: {arg}"
)
tc.apply_dtype(dtype, expr) tc.apply_dtype(dtype, expr)
case _: case _:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment