Skip to content
Snippets Groups Projects

Freeze casts of bare constants to typed PsConstantExprs

Merged Daniel Bauer requested to merge hyteg/pystencils:bauerd/cast-bare-constant into v2.0-dev
4 files
+ 46
4
Compare changes
  • Side-by-side
  • Inline
Files
4
@@ -58,7 +58,7 @@ from ..ast.expressions import (
@@ -58,7 +58,7 @@ from ..ast.expressions import (
)
)
from ..constants import PsConstant
from ..constants import PsConstant
from ...types import PsStructType, PsType
from ...types import PsNumericType, PsStructType, PsType
from ..exceptions import PsInputError
from ..exceptions import PsInputError
from ..functions import PsMathFunction, MathFunctions
from ..functions import PsMathFunction, MathFunctions
@@ -462,7 +462,7 @@ class FreezeExpressions:
@@ -462,7 +462,7 @@ class FreezeExpressions:
]
]
return cast(PsCall, args[0])
return cast(PsCall, args[0])
def map_CastFunc(self, cast_expr: CastFunc) -> PsCast:
def map_CastFunc(self, cast_expr: CastFunc) -> PsCast | PsConstantExpr:
dtype: PsType
dtype: PsType
match cast_expr.dtype:
match cast_expr.dtype:
case DynamicType.NUMERIC_TYPE:
case DynamicType.NUMERIC_TYPE:
@@ -472,7 +472,19 @@ class FreezeExpressions:
@@ -472,7 +472,19 @@ class FreezeExpressions:
case other if isinstance(other, PsType):
case other if isinstance(other, PsType):
dtype = other
dtype = other
return PsCast(dtype, self.visit_expr(cast_expr.expr))
arg = self.visit_expr(cast_expr.expr)
 
if (
 
isinstance(arg, PsConstantExpr)
 
and arg.constant.dtype is None
 
and isinstance(dtype, PsNumericType)
 
):
 
# As of now, the typifier can not infer the type of a bare constant.
 
# However, untyped constants may not appear in ASTs from which
 
# kernel functions are generated. Therefore, we annotate constants
 
# instead of casting them.
 
return PsConstantExpr(arg.constant.interpret_as(dtype))
 
else:
 
return PsCast(dtype, arg)
def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel:
def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel:
arg1, arg2 = [self.visit_expr(arg) for arg in rel.args]
arg1, arg2 = [self.visit_expr(arg) for arg in rel.args]
Loading