Skip to content
Snippets Groups Projects
Commit 03bd4bf4 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fallback to default types in independent nested contexts

parent e9bf1249
No related branches found
No related tags found
1 merge request!418Nesting of Type Contexts, Type Hints, and Improved Array Typing
Pipeline #68016 passed
...@@ -646,9 +646,8 @@ class Typifier: ...@@ -646,9 +646,8 @@ class Typifier:
self.visit_expr(op2, args_tc) self.visit_expr(op2, args_tc)
if args_tc.target_type is None: if args_tc.target_type is None:
raise TypificationError( args_tc.apply_hint(ToDefault(self._ctx.default_dtype))
f"Unable to determine type of arguments to relation: {expr}"
)
if not isinstance(args_tc.target_type, PsNumericType): if not isinstance(args_tc.target_type, PsNumericType):
raise TypificationError( raise TypificationError(
f"Invalid type in arguments to relation\n" f"Invalid type in arguments to relation\n"
...@@ -738,9 +737,7 @@ class Typifier: ...@@ -738,9 +737,7 @@ class Typifier:
self.visit_expr(arg, arg_tc) self.visit_expr(arg, arg_tc)
if arg_tc.target_type is None: if arg_tc.target_type is None:
raise TypificationError( arg_tc.apply_hint(ToDefault(self._ctx.default_dtype))
f"Unable to determine type of argument to Cast: {arg}"
)
tc.apply_dtype(dtype, expr) tc.apply_dtype(dtype, expr)
......
...@@ -6,6 +6,7 @@ from typing import cast ...@@ -6,6 +6,7 @@ from typing import cast
from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import ( from pystencils.backend.ast.structural import (
PsDeclaration, PsDeclaration,
PsAssignment, PsAssignment,
...@@ -14,7 +15,6 @@ from pystencils.backend.ast.structural import ( ...@@ -14,7 +15,6 @@ from pystencils.backend.ast.structural import (
PsBlock, PsBlock,
) )
from pystencils.backend.ast.expressions import ( from pystencils.backend.ast.expressions import (
PsAddressOf,
PsArrayInitList, PsArrayInitList,
PsCast, PsCast,
PsConstantExpr, PsConstantExpr,
...@@ -32,7 +32,6 @@ from pystencils.backend.ast.expressions import ( ...@@ -32,7 +32,6 @@ from pystencils.backend.ast.expressions import (
PsLt, PsLt,
PsCall, PsCall,
PsTernary, PsTernary,
PsArrayInitList,
) )
from pystencils.backend.constants import PsConstant from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import CFunction from pystencils.backend.functions import CFunction
...@@ -218,6 +217,37 @@ def test_default_typing(): ...@@ -218,6 +217,37 @@ def test_default_typing():
check(expr) check(expr)
def test_nested_contexts_defaults():
ctx = KernelCreationContext(default_dtype=Fp(16))
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y, z = sp.symbols("x, y, z")
fortytwo = PsExpression.make(PsConstant(42))
expr = typify(PsEq(fortytwo, fortytwo))
assert expr.dtype == Bool(const=True)
assert fortytwo.dtype == Fp(16, const=True)
assert fortytwo.constant.dtype == Fp(16, const=True)
decl = freeze(Assignment(x, sp.Ge(3, 4, evaluate=False)))
decl = typify(decl)
assert ctx.get_symbol("x").dtype == Bool()
for c in dfs_preorder(decl.rhs, lambda n: isinstance(n, PsConstantExpr)):
assert c.dtype == Fp(16, const=True)
assert c.constant.dtype == Fp(16, const=True)
thirtyone = PsExpression.make(PsConstant(31))
decl = PsDeclaration(freeze(y), PsCast(Fp(32), thirtyone))
decl = typify(decl)
assert ctx.get_symbol("y").dtype == Fp(32)
assert thirtyone.dtype == Fp(16, const=True)
assert thirtyone.constant.dtype == Fp(16, const=True)
def test_constant_decls(): def test_constant_decls():
ctx = KernelCreationContext(default_dtype=Fp(16)) ctx = KernelCreationContext(default_dtype=Fp(16))
freeze = FreezeExpressions(ctx) freeze = FreezeExpressions(ctx)
...@@ -571,19 +601,3 @@ def test_cfunction(): ...@@ -571,19 +601,3 @@ def test_cfunction():
with pytest.raises(TypificationError): with pytest.raises(TypificationError):
_ = typify(PsCall(threeway, (x, p))) _ = typify(PsCall(threeway, (x, p)))
def test_inference_fails():
ctx = KernelCreationContext()
typify = Typifier(ctx)
x = PsExpression.make(PsConstant(42))
with pytest.raises(TypificationError):
typify(PsEq(x, x))
with pytest.raises(TypificationError):
typify(PsArrayInitList([x]))
with pytest.raises(TypificationError):
typify(PsCast(ctx.default_dtype, x))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment