From b305ce9a2bf6b4526cca491e113d4129029202cc Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 18 Jul 2024 14:42:50 +0200 Subject: [PATCH] Change symbol typificaiton: Infer unknown LHS data types from RHS --- .../backend/kernelcreation/context.py | 6 +- .../backend/kernelcreation/typification.py | 77 +++++++++++++++---- .../kernelcreation/test_typification.py | 48 +++++++++++- .../test_conditional_field_access.py | 3 - 4 files changed, 111 insertions(+), 23 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 9df186470..916274314 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -12,7 +12,7 @@ from ...sympyextensions.typed_sympy import TypedSymbol from ..symbols import PsSymbol from ..arrays import PsLinearizedArray -from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType +from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType, deconstify from ..constraints import KernelParamsConstraint from ..exceptions import PsInternalCompilerError, KernelConstraintsError @@ -63,8 +63,8 @@ class KernelCreationContext: default_dtype: PsNumericType = DEFAULTS.numeric_dtype, index_dtype: PsIntegerType = DEFAULTS.index_dtype, ): - self._default_dtype = default_dtype - self._index_dtype = index_dtype + self._default_dtype = deconstify(default_dtype) + self._index_dtype = deconstify(index_dtype) self._symbols: dict[str, PsSymbol] = dict() diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 8ef6edd24..49bc30286 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -68,10 +68,18 @@ class TypeContext: - A set of restrictions on the target type: - `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides - Additional restrictions may be added in the future. + + The type context also tracks the tree traversal of the typifier: + + - ``is_lhs`` is set to True while a left-hand side expression is being processed, + and `False` while a right-hand side expression is processed. """ def __init__( - self, target_type: PsType | None = None, require_nonconst: bool = False + self, + target_type: PsType | None = None, + require_nonconst: bool = False, + is_lhs: bool = False ): self._require_nonconst = require_nonconst self._deferred_exprs: list[PsExpression] = [] @@ -80,6 +88,8 @@ class TypeContext: self._fix_constness(target_type) if target_type is not None else None ) + self._is_lhs = is_lhs + @property def target_type(self) -> PsType | None: return self._target_type @@ -87,6 +97,14 @@ class TypeContext: @property def require_nonconst(self) -> bool: return self._require_nonconst + + @property + def is_lhs(self) -> bool: + return self._is_lhs + + @is_lhs.setter + def is_lhs(self, value: bool): + self._is_lhs = value def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None): """Applies the given ``dtype`` to this type context, and optionally to the given expression. @@ -171,8 +189,10 @@ class TypeContext: ) case PsSymbolExpr(symb): - assert symb.dtype is not None - if not self._compatible(symb.dtype): + if symb.dtype is None: + # Symbols are not forced to constness + symb.dtype = deconstify(self._target_type) + elif not self._compatible(symb.dtype): raise TypificationError( f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n" f" Symbol type: {symb.dtype}\n" @@ -262,7 +282,11 @@ class Typifier: The following general rules apply: - - The context's `default_dtype` is applied to all untyped symbols + - The context's ``default_dtype`` is applied to all untyped symbols encountered inside a right-hand side expression + - If an untyped symbol is encountered on an assignment's left-hand side, it will first be attempted to infer its + type from the right-hand side. If that fails, the context's ``default_dtype`` will be applied. + - It is an error if an untyped symbol occurs in the same type context as a typed symbol or constant + with a non-default data type. - By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's left-hand side @@ -280,7 +304,12 @@ class Typifier: def __call__(self, node: NodeT) -> NodeT: if isinstance(node, PsExpression): - self.visit_expr(node, TypeContext()) + tc = TypeContext() + self.visit_expr(node, tc) + + if tc.target_type is None: + # no type could be inferred -> take the default + tc.apply_dtype(self._ctx.default_dtype) else: self.visit(node) return node @@ -305,19 +334,31 @@ class Typifier: case PsDeclaration(lhs, rhs): tc = TypeContext() - # LHS defines target type; type context carries it to RHS + + tc.is_lhs = True self.visit_expr(lhs, tc) - assert tc.target_type is not None + tc.is_lhs = False + self.visit_expr(rhs, tc) + if tc.target_type is None: + # no type has been inferred -> use the default dtype + tc.apply_dtype(self._ctx.default_dtype) + case PsAssignment(lhs, rhs): - tc_lhs = TypeContext(require_nonconst=True) + tc_lhs = TypeContext(require_nonconst=True, is_lhs=True) self.visit_expr(lhs, tc_lhs) - assert tc_lhs.target_type is not None - tc_rhs = TypeContext(tc_lhs.target_type, require_nonconst=False) + tc_rhs = TypeContext(target_type=tc_lhs.target_type) self.visit_expr(rhs, tc_rhs) + if tc_rhs.target_type is None: + tc_rhs.apply_dtype(self._ctx.default_dtype) + + if tc_lhs.target_type is None: + assert tc_rhs.target_type is not None + tc_lhs.apply_dtype(deconstify(tc_rhs.target_type)) + case PsConditional(cond, branch_true, branch_false): cond_tc = TypeContext(PsBoolType()) self.visit_expr(cond, cond_tc) @@ -330,6 +371,7 @@ class Typifier: case PsLoop(ctr, start, stop, step, body): if ctr.symbol.dtype is None: ctr.symbol.apply_dtype(self._ctx.index_dtype) + ctr.dtype = ctr.symbol.get_dtype() tc_index = TypeContext(ctr.symbol.dtype) self.visit_expr(start, tc_index) @@ -355,11 +397,16 @@ class Typifier: either ``apply_dtype`` or ``infer_dtype``. """ match expr: - case PsSymbolExpr(_): - if expr.symbol.dtype is None: - expr.symbol.dtype = self._ctx.default_dtype - - tc.apply_dtype(expr.symbol.dtype, expr) + case PsSymbolExpr(symb): + if tc.is_lhs: + if symb.dtype is not None: + tc.apply_dtype(symb.dtype, expr) + elif tc.is_lhs: + tc.infer_dtype(expr) + else: + if symb.dtype is None: + symb.dtype = self._ctx.default_dtype + tc.apply_dtype(symb.dtype, expr) case PsConstantExpr(c): if c.dtype is not None: diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 5c2631e1e..ca170401f 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -4,7 +4,7 @@ import numpy as np from typing import cast -from pystencils import Assignment, TypedSymbol, Field, FieldType +from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment from pystencils.backend.ast.structural import ( PsDeclaration, @@ -186,7 +186,7 @@ def test_typify_structs(): fasm = typify(fasm) -def test_contextual_typing(): +def test_default_typing(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) typify = Typifier(ctx) @@ -213,6 +213,43 @@ def test_contextual_typing(): check(expr) +def test_lhs_inference(): + ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x, y, z = sp.symbols("x, y, z") + q = TypedSymbol("q", np.float32) + w = TypedSymbol("w", np.float16) + + # Type of the LHS is propagated to untyped RHS symbols + + asm = Assignment(x, 3 - q) + fasm = typify(freeze(asm)) + + assert ctx.get_symbol("x").dtype == Fp(32) + assert fasm.lhs.dtype == constify(Fp(32)) + + asm = Assignment(y, 3 - w) + fasm = typify(freeze(asm)) + + assert ctx.get_symbol("y").dtype == Fp(16) + assert fasm.lhs.dtype == constify(Fp(16)) + + fasm = PsAssignment(PsExpression.make(ctx.get_symbol("z")), freeze(3 - w)) + fasm = typify(fasm) + + assert ctx.get_symbol("z").dtype == Fp(16) + assert fasm.lhs.dtype == Fp(16) + + fasm = PsDeclaration(PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q))) + fasm = typify(fasm) + + assert ctx.get_symbol("r").dtype == Bool() + assert fasm.lhs.dtype == constify(Bool()) + assert fasm.rhs.dtype == constify(Bool()) + + def test_erronous_typing(): ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) freeze = FreezeExpressions(ctx) @@ -227,16 +264,23 @@ def test_erronous_typing(): with pytest.raises(TypificationError): typify(expr) + # Conflict between LHS and RHS symbols asm = Assignment(q, 3 - w) fasm = freeze(asm) with pytest.raises(TypificationError): typify(fasm) + # Do not propagate types back from LHS symbols to RHS symbols asm = Assignment(q, 3 - x) fasm = freeze(asm) with pytest.raises(TypificationError): typify(fasm) + asm = AddAugmentedAssignment(z, 3 - q) + fasm = freeze(asm) + with pytest.raises(TypificationError): + typify(fasm) + def test_typify_integer_binops(): ctx = KernelCreationContext() diff --git a/tests/symbolics/test_conditional_field_access.py b/tests/symbolics/test_conditional_field_access.py index e18ffc56a..1dbc88cf4 100644 --- a/tests/symbolics/test_conditional_field_access.py +++ b/tests/symbolics/test_conditional_field_access.py @@ -51,9 +51,6 @@ def add_fixed_constant_boundary_handling(assignments, with_cse): @pytest.mark.parametrize('dtype', ('float64', 'float32')) @pytest.mark.parametrize('with_cse', (False, 'with_cse')) def test_boundary_check(dtype, with_cse): - if with_cse: - pytest.xfail("Doesn't typify correctly yet.") - f, g = ps.fields(f"f, g : {dtype}[2D]") stencil = ps.Assignment(g[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4) -- GitLab