diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 49bc302867c51431210bb392e501e4269baee428..fc085e2be99f61204cde92438811a3b4e41c8bf7 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -68,18 +68,12 @@ class TypeContext: - A set of restrictions on the target type: - `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides - Additional restrictions may be added in the future. - - The type context also tracks the tree traversal of the typifier: - - - ``is_lhs`` is set to True while a left-hand side expression is being processed, - and `False` while a right-hand side expression is processed. """ def __init__( self, target_type: PsType | None = None, require_nonconst: bool = False, - is_lhs: bool = False ): self._require_nonconst = require_nonconst self._deferred_exprs: list[PsExpression] = [] @@ -88,8 +82,6 @@ class TypeContext: self._fix_constness(target_type) if target_type is not None else None ) - self._is_lhs = is_lhs - @property def target_type(self) -> PsType | None: return self._target_type @@ -97,14 +89,6 @@ class TypeContext: @property def require_nonconst(self) -> bool: return self._require_nonconst - - @property - def is_lhs(self) -> bool: - return self._is_lhs - - @is_lhs.setter - def is_lhs(self, value: bool): - self._is_lhs = value def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None): """Applies the given ``dtype`` to this type context, and optionally to the given expression. @@ -333,29 +317,41 @@ class Typifier: self.visit(s) case PsDeclaration(lhs, rhs): + # Only if the LHS is an untyped symbol, infer its type from the RHS + infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None + tc = TypeContext() - tc.is_lhs = True - self.visit_expr(lhs, tc) - tc.is_lhs = False + if infer_lhs: + tc.infer_dtype(lhs) + else: + self.visit_expr(lhs, tc) + assert tc.target_type is not None self.visit_expr(rhs, tc) - if tc.target_type is None: + if infer_lhs and tc.target_type is None: # no type has been inferred -> use the default dtype tc.apply_dtype(self._ctx.default_dtype) case PsAssignment(lhs, rhs): - tc_lhs = TypeContext(require_nonconst=True, is_lhs=True) - self.visit_expr(lhs, tc_lhs) + infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None + + tc_lhs = TypeContext(require_nonconst=True) + + if infer_lhs: + tc_lhs.infer_dtype(lhs) + else: + self.visit_expr(lhs, tc_lhs) + assert tc_lhs.target_type is not None tc_rhs = TypeContext(target_type=tc_lhs.target_type) self.visit_expr(rhs, tc_rhs) - if tc_rhs.target_type is None: - tc_rhs.apply_dtype(self._ctx.default_dtype) - - if tc_lhs.target_type is None: + if infer_lhs: + if tc_rhs.target_type is None: + tc_rhs.apply_dtype(self._ctx.default_dtype) + assert tc_rhs.target_type is not None tc_lhs.apply_dtype(deconstify(tc_rhs.target_type)) @@ -398,15 +394,9 @@ class Typifier: """ match expr: case PsSymbolExpr(symb): - if tc.is_lhs: - if symb.dtype is not None: - tc.apply_dtype(symb.dtype, expr) - elif tc.is_lhs: - tc.infer_dtype(expr) - else: - if symb.dtype is None: - symb.dtype = self._ctx.default_dtype - tc.apply_dtype(symb.dtype, expr) + if symb.dtype is None: + symb.dtype = self._ctx.default_dtype + tc.apply_dtype(symb.dtype, expr) case PsConstantExpr(c): if c.dtype is not None: diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index ca170401fcc78900e5fd9a6e43ba118dccce2c5a..d3da7e8881d631266d58ee6cc0d4d3612a2900a1 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -16,6 +16,7 @@ from pystencils.backend.ast.structural import ( from pystencils.backend.ast.expressions import ( PsConstantExpr, PsSymbolExpr, + PsSubscript, PsBinOp, PsAnd, PsOr, @@ -27,12 +28,12 @@ from pystencils.backend.ast.expressions import ( PsGt, PsLt, PsCall, - PsTernary + PsTernary, ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import CFunction from pystencils.types import constify, create_type, create_numeric_type -from pystencils.types.quick import Fp, Int, Bool +from pystencils.types.quick import Fp, Int, Bool, Arr from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -242,9 +243,11 @@ def test_lhs_inference(): assert ctx.get_symbol("z").dtype == Fp(16) assert fasm.lhs.dtype == Fp(16) - fasm = PsDeclaration(PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q))) + fasm = PsDeclaration( + PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q)) + ) fasm = typify(fasm) - + assert ctx.get_symbol("r").dtype == Bool() assert fasm.lhs.dtype == constify(Bool()) assert fasm.rhs.dtype == constify(Bool()) @@ -282,6 +285,26 @@ def test_erronous_typing(): typify(fasm) +def test_invalid_indices(): + ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) + typify = Typifier(ctx) + + arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64)))) + x, y, z = [PsExpression.make(ctx.get_symbol(x)) for x in "xyz"] + + # Using default-typed symbols as array indices is illegal when the default type is a float + + fasm = PsAssignment(PsSubscript(arr, x + y), z) + + with pytest.raises(TypificationError): + typify(fasm) + + fasm = PsAssignment(z, PsSubscript(arr, x + y)) + + with pytest.raises(TypificationError): + typify(fasm) + + def test_typify_integer_binops(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) @@ -410,7 +433,7 @@ def test_invalid_conditions(): with pytest.raises(TypificationError): typify(cond) - + def test_typify_ternary(): ctx = KernelCreationContext() typify = Typifier(ctx)