diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 66da00c114860aa4be7084311251de7f5c3f04bb..d85c5341073e548301ca74131ad091ef5db35595 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -646,9 +646,8 @@ class Typifier: self.visit_expr(op2, args_tc) if args_tc.target_type is None: - raise TypificationError( - f"Unable to determine type of arguments to relation: {expr}" - ) + args_tc.apply_hint(ToDefault(self._ctx.default_dtype)) + if not isinstance(args_tc.target_type, PsNumericType): raise TypificationError( f"Invalid type in arguments to relation\n" @@ -738,9 +737,7 @@ class Typifier: self.visit_expr(arg, arg_tc) if arg_tc.target_type is None: - raise TypificationError( - f"Unable to determine type of argument to Cast: {arg}" - ) + arg_tc.apply_hint(ToDefault(self._ctx.default_dtype)) tc.apply_dtype(dtype, expr) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 7424977be4b90ea0794d2dd3f2c82aef372acbf7..0c707e33125d05cf5c1261ec68e899090f182bbb 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -6,6 +6,7 @@ from typing import cast from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment +from pystencils.backend.ast import dfs_preorder from pystencils.backend.ast.structural import ( PsDeclaration, PsAssignment, @@ -14,7 +15,6 @@ from pystencils.backend.ast.structural import ( PsBlock, ) from pystencils.backend.ast.expressions import ( - PsAddressOf, PsArrayInitList, PsCast, PsConstantExpr, @@ -32,7 +32,6 @@ from pystencils.backend.ast.expressions import ( PsLt, PsCall, PsTernary, - PsArrayInitList, ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import CFunction @@ -218,6 +217,37 @@ def test_default_typing(): check(expr) +def test_nested_contexts_defaults(): + ctx = KernelCreationContext(default_dtype=Fp(16)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x, y, z = sp.symbols("x, y, z") + + fortytwo = PsExpression.make(PsConstant(42)) + + expr = typify(PsEq(fortytwo, fortytwo)) + assert expr.dtype == Bool(const=True) + assert fortytwo.dtype == Fp(16, const=True) + assert fortytwo.constant.dtype == Fp(16, const=True) + + decl = freeze(Assignment(x, sp.Ge(3, 4, evaluate=False))) + decl = typify(decl) + + assert ctx.get_symbol("x").dtype == Bool() + for c in dfs_preorder(decl.rhs, lambda n: isinstance(n, PsConstantExpr)): + assert c.dtype == Fp(16, const=True) + assert c.constant.dtype == Fp(16, const=True) + + thirtyone = PsExpression.make(PsConstant(31)) + decl = PsDeclaration(freeze(y), PsCast(Fp(32), thirtyone)) + decl = typify(decl) + + assert ctx.get_symbol("y").dtype == Fp(32) + assert thirtyone.dtype == Fp(16, const=True) + assert thirtyone.constant.dtype == Fp(16, const=True) + + def test_constant_decls(): ctx = KernelCreationContext(default_dtype=Fp(16)) freeze = FreezeExpressions(ctx) @@ -571,19 +601,3 @@ def test_cfunction(): with pytest.raises(TypificationError): _ = typify(PsCall(threeway, (x, p))) - - -def test_inference_fails(): - ctx = KernelCreationContext() - typify = Typifier(ctx) - - x = PsExpression.make(PsConstant(42)) - - with pytest.raises(TypificationError): - typify(PsEq(x, x)) - - with pytest.raises(TypificationError): - typify(PsArrayInitList([x])) - - with pytest.raises(TypificationError): - typify(PsCast(ctx.default_dtype, x))