From d1a9c0efb2a07a517abfb3bbbbd336c695605aaf Mon Sep 17 00:00:00 2001 From: Daniel Bauer <daniel.j.bauer@fau.de> Date: Wed, 24 Jul 2024 16:49:06 +0200 Subject: [PATCH] freeze casts of bare constants to typed PsConstantExprs --- .../backend/kernelcreation/freeze.py | 18 +++++++++++++++--- .../backend/kernelcreation/typification.py | 9 ++++++++- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 25ce28115..3f1ca8ede 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -58,7 +58,7 @@ from ..ast.expressions import ( ) from ..constants import PsConstant -from ...types import PsStructType, PsType +from ...types import PsNumericType, PsStructType, PsType from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions @@ -462,7 +462,7 @@ class FreezeExpressions: ] return cast(PsCall, args[0]) - def map_CastFunc(self, cast_expr: CastFunc) -> PsCast: + def map_CastFunc(self, cast_expr: CastFunc) -> PsCast | PsConstantExpr: dtype: PsType match cast_expr.dtype: case DynamicType.NUMERIC_TYPE: @@ -472,7 +472,19 @@ class FreezeExpressions: case other if isinstance(other, PsType): dtype = other - return PsCast(dtype, self.visit_expr(cast_expr.expr)) + arg = self.visit_expr(cast_expr.expr) + if ( + isinstance(arg, PsConstantExpr) + and arg.constant.dtype is None + and isinstance(dtype, PsNumericType) + ): + # As of now, the typifier can not infer the type of a bare constant. + # However, untyped constants may not appear in ASTs from which + # kernel functions are generated. Therefore, we annotate constants + # instead of casting them. + return PsConstantExpr(arg.constant.interpret_as(dtype)) + else: + return PsCast(dtype, arg) def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel: arg1, arg2 = [self.visit_expr(arg) for arg in rel.args] diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index fc085e2be..ce2d24f98 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -594,7 +594,14 @@ class Typifier: tc.apply_dtype(arr_type, expr) case PsCast(dtype, arg): - self.visit_expr(arg, TypeContext()) + arg_tc = TypeContext() + self.visit_expr(arg, arg_tc) + + if arg_tc.target_type is None: + raise TypificationError( + f"Unable to determine type of argument to Cast: {arg}" + ) + tc.apply_dtype(dtype, expr) case _: -- GitLab