From f1486c07f7c2500baf21f705cab71d1e1ba2ddaf Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Sun, 31 Mar 2024 15:31:15 +0200 Subject: [PATCH] Typing fixes: - Symbols now also receive the const-qualification of `default_dtype` - Fix const propagation to structs - Add doc comment explaining all this --- .../backend/kernelcreation/typification.py | 84 ++++++++++++------- .../kernelcreation/test_typification.py | 12 +-- 2 files changed, 62 insertions(+), 34 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 0ea0e9d51..b01b1d35e 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -65,13 +65,24 @@ class TypeContext: self._fix_constness(target_type) if target_type is not None else None ) - def apply_dtype(self, expr: PsExpression | None, dtype: PsType): - """Applies the given ``dtype`` to the given expression inside this type context. + @property + def target_type(self) -> PsType | None: + return self._target_type + + @property + def require_nonconst(self) -> bool: + return self._require_nonconst + + def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None): + """Applies the given ``dtype`` to this type context, and optionally to the given expression. - The given expression will be covered by this type context. If the context's target_type is already known, it must be compatible with the given dtype. If the target type is still unknown, target_type is set to dtype and retroactively applied to all deferred expressions. + + If an expression is specified, it will be covered by the type context. + If the expression already has a data type set, it must be compatible with the target type + and will be replaced by it. """ dtype = self._fix_constness(dtype) @@ -96,7 +107,8 @@ class TypeContext: Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is called on this context. - If the expression already has a data type set, it must be equal to the inferred type. + If the expression already has a data type set, it must be compatible with the target type + and will be replaced by it. """ if self._target_type is None: @@ -136,9 +148,8 @@ class TypeContext: ) case PsSymbolExpr(symb): - if symb.dtype is None: - symb.dtype = self._target_type - elif not self._compatible(symb.dtype): + assert symb.dtype is not None + if not self._compatible(symb.dtype): raise TypificationError( f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n" f" Symbol type: {symb.dtype}\n" @@ -183,14 +194,12 @@ class TypeContext: else: return constify(dtype) - @property - def target_type(self) -> PsType | None: - return self._target_type - class Typifier: """Apply data types to expressions. + **Contextual Typing** + The Typifier will traverse the AST and apply a contextual typing scheme to figure out the data types of all encountered expressions. To this end, it covers each expression tree with a set of disjoint typing contexts. @@ -211,6 +220,21 @@ class Typifier: in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon as the target type is fixed. + **Typing Rules** + + The following rules apply: + + - The context's `default_dtype` is applied to all untyped symbols + - By default, all expressions receive a ``const`` type unless otherwise required + - The left-hand side of any non-declaration assignment must not be ``const`` + + **Typing of symbol expressions** + + Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but + not necessarily their const-qualification. + A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type, + and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`, + but not vice versa. """ def __init__(self, ctx: KernelCreationContext): @@ -284,20 +308,23 @@ class Typifier: match expr: case PsSymbolExpr(_): if expr.symbol.dtype is None: - tc.apply_dtype(expr, self._ctx.default_dtype) - else: - tc.apply_dtype(expr, expr.symbol.dtype) + expr.symbol.dtype = self._ctx.default_dtype - case PsConstantExpr(_): - tc.infer_dtype(expr) + tc.apply_dtype(expr.symbol.dtype, expr) + + case PsConstantExpr(c): + if c.dtype is not None: + tc.apply_dtype(c.dtype, expr) + else: + tc.infer_dtype(expr) case PsArrayAccess(bptr, idx): - tc.apply_dtype(expr, bptr.array.element_type) + tc.apply_dtype(bptr.array.element_type, expr) index_tc = TypeContext() self.visit_expr(idx, index_tc) if index_tc.target_type is None: - index_tc.apply_dtype(idx, self._ctx.index_dtype) + index_tc.apply_dtype(self._ctx.index_dtype, idx) elif not isinstance(index_tc.target_type, PsIntegerType): raise TypificationError( f"Array index is not of integer type: {idx} has type {index_tc.target_type}" @@ -312,12 +339,12 @@ class Typifier: "Type of subscript base is not subscriptable." ) - tc.apply_dtype(expr, arr_tc.target_type.base_type) + tc.apply_dtype(arr_tc.target_type.base_type, expr) index_tc = TypeContext() self.visit_expr(idx, index_tc) if index_tc.target_type is None: - index_tc.apply_dtype(idx, self._ctx.index_dtype) + index_tc.apply_dtype(self._ctx.index_dtype, idx) elif not isinstance(index_tc.target_type, PsIntegerType): raise TypificationError( f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}" @@ -332,7 +359,7 @@ class Typifier: "Type of argument to a Deref is not dereferencable" ) - tc.apply_dtype(expr, ptr_tc.target_type.base_type) + tc.apply_dtype(ptr_tc.target_type.base_type, expr) case PsAddressOf(arg): arg_tc = TypeContext() @@ -344,10 +371,11 @@ class Typifier: ) ptr_type = PsPointerType(arg_tc.target_type, True) - tc.apply_dtype(expr, ptr_type) + tc.apply_dtype(ptr_type, expr) case PsLookup(aggr, member_name): - aggr_tc = TypeContext(None) + # Members of a struct type inherit the struct type's `const` qualifier + aggr_tc = TypeContext(None, require_nonconst=tc.require_nonconst) self.visit_expr(aggr, aggr_tc) aggr_type = aggr_tc.target_type @@ -361,12 +389,12 @@ class Typifier: raise TypificationError( f"Aggregate of type {aggr_type} does not have a member {member}." ) - + member_type = member.dtype if aggr_type.const: member_type = constify(member_type) - tc.apply_dtype(expr, member_type) + tc.apply_dtype(member_type, expr) case PsBinOp(op1, op2): self.visit_expr(op1, tc) @@ -405,14 +433,14 @@ class Typifier: f"{len(items)} items as {tc.target_type}" ) else: - items_tc.apply_dtype(None, tc.target_type.base_type) + items_tc.apply_dtype(tc.target_type.base_type) else: arr_type = PsArrayType(items_tc.target_type, len(items)) - tc.apply_dtype(expr, arr_type) + tc.apply_dtype(arr_type, expr) case PsCast(dtype, arg): self.visit_expr(arg, TypeContext()) - tc.apply_dtype(expr, dtype) + tc.apply_dtype(dtype, expr) case _: raise NotImplementedError(f"Can't typify {expr}") diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 6a1a02860..5e9d7cbec 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -45,7 +45,7 @@ def test_typify_simple(): assert cs.dtype == constify(ctx.default_dtype) case PsSymbolExpr(symb): assert symb.name in "xyz" - assert symb.dtype == constify(ctx.default_dtype) + assert symb.dtype == ctx.default_dtype case PsBinOp(op1, op2): check(op1) check(op2) @@ -109,7 +109,6 @@ def test_lhs_constness(): ) x, y, z = sp.symbols("x, y, z") - q = TypedSymbol("q", Fp(32, const=True)) # Assignment RHS may not be const asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y))) @@ -119,9 +118,6 @@ def test_lhs_constness(): with pytest.raises(TypificationError): _ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y))) - # with pytest.raises(TypificationError): - # _ = typify(freeze(Assignment(q, x - z))) - np_struct = np.dtype([("size", np.uint32), ("data", np.float32)]) struct_type = constify(create_type(np_struct)) struct_field = Field.create_generic( @@ -146,6 +142,10 @@ def test_typify_structs(): fasm = freeze(asm) fasm = typify(fasm) + asm = Assignment(f.absolute_access((0,), "data"), x) + fasm = freeze(asm) + fasm = typify(fasm) + # Bad asm = Assignment(x, f.absolute_access((0,), "size")) fasm = freeze(asm) @@ -170,7 +170,7 @@ def test_contextual_typing(): assert cs.dtype == constify(ctx.default_dtype) case PsSymbolExpr(symb): assert symb.name in "xyz" - assert symb.dtype == constify(ctx.default_dtype) + assert symb.dtype == ctx.default_dtype case PsBinOp(op1, op2): check(op1) check(op2) -- GitLab