From f1486c07f7c2500baf21f705cab71d1e1ba2ddaf Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 31 Mar 2024 15:31:15 +0200
Subject: [PATCH] Typing fixes:  - Symbols now also receive the
 const-qualification of `default_dtype`  - Fix const propagation to structs  -
 Add doc comment explaining all this

---
 .../backend/kernelcreation/typification.py    | 84 ++++++++++++-------
 .../kernelcreation/test_typification.py       | 12 +--
 2 files changed, 62 insertions(+), 34 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 0ea0e9d51..b01b1d35e 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -65,13 +65,24 @@ class TypeContext:
             self._fix_constness(target_type) if target_type is not None else None
         )
 
-    def apply_dtype(self, expr: PsExpression | None, dtype: PsType):
-        """Applies the given ``dtype`` to the given expression inside this type context.
+    @property
+    def target_type(self) -> PsType | None:
+        return self._target_type
+
+    @property
+    def require_nonconst(self) -> bool:
+        return self._require_nonconst
+
+    def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
+        """Applies the given ``dtype`` to this type context, and optionally to the given expression.
 
-        The given expression will be covered by this type context.
         If the context's target_type is already known, it must be compatible with the given dtype.
         If the target type is still unknown, target_type is set to dtype and retroactively applied
         to all deferred expressions.
+
+        If an expression is specified, it will be covered by the type context.
+        If the expression already has a data type set, it must be compatible with the target type
+        and will be replaced by it.
         """
 
         dtype = self._fix_constness(dtype)
@@ -96,7 +107,8 @@ class TypeContext:
         Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is
         called on this context.
 
-        If the expression already has a data type set, it must be equal to the inferred type.
+        If the expression already has a data type set, it must be compatible with the target type
+        and will be replaced by it.
         """
 
         if self._target_type is None:
@@ -136,9 +148,8 @@ class TypeContext:
                         )
 
                 case PsSymbolExpr(symb):
-                    if symb.dtype is None:
-                        symb.dtype = self._target_type
-                    elif not self._compatible(symb.dtype):
+                    assert symb.dtype is not None
+                    if not self._compatible(symb.dtype):
                         raise TypificationError(
                             f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n"
                             f"    Symbol type: {symb.dtype}\n"
@@ -183,14 +194,12 @@ class TypeContext:
         else:
             return constify(dtype)
 
-    @property
-    def target_type(self) -> PsType | None:
-        return self._target_type
-
 
 class Typifier:
     """Apply data types to expressions.
 
+    **Contextual Typing**
+
     The Typifier will traverse the AST and apply a contextual typing scheme to figure out
     the data types of all encountered expressions.
     To this end, it covers each expression tree with a set of disjoint typing contexts.
@@ -211,6 +220,21 @@ class Typifier:
     in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon
     as the target type is fixed.
 
+    **Typing Rules**
+
+    The following rules apply:
+
+     - The context's `default_dtype` is applied to all untyped symbols
+     - By default, all expressions receive a ``const`` type unless otherwise required
+     - The left-hand side of any non-declaration assignment must not be ``const``
+
+    **Typing of symbol expressions**
+
+    Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but
+    not necessarily their const-qualification.
+    A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type, 
+    and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`,
+    but not vice versa.
     """
 
     def __init__(self, ctx: KernelCreationContext):
@@ -284,20 +308,23 @@ class Typifier:
         match expr:
             case PsSymbolExpr(_):
                 if expr.symbol.dtype is None:
-                    tc.apply_dtype(expr, self._ctx.default_dtype)
-                else:
-                    tc.apply_dtype(expr, expr.symbol.dtype)
+                    expr.symbol.dtype = self._ctx.default_dtype
 
-            case PsConstantExpr(_):
-                tc.infer_dtype(expr)
+                tc.apply_dtype(expr.symbol.dtype, expr)
+
+            case PsConstantExpr(c):
+                if c.dtype is not None:
+                    tc.apply_dtype(c.dtype, expr)
+                else:
+                    tc.infer_dtype(expr)
 
             case PsArrayAccess(bptr, idx):
-                tc.apply_dtype(expr, bptr.array.element_type)
+                tc.apply_dtype(bptr.array.element_type, expr)
 
                 index_tc = TypeContext()
                 self.visit_expr(idx, index_tc)
                 if index_tc.target_type is None:
-                    index_tc.apply_dtype(idx, self._ctx.index_dtype)
+                    index_tc.apply_dtype(self._ctx.index_dtype, idx)
                 elif not isinstance(index_tc.target_type, PsIntegerType):
                     raise TypificationError(
                         f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
@@ -312,12 +339,12 @@ class Typifier:
                         "Type of subscript base is not subscriptable."
                     )
 
-                tc.apply_dtype(expr, arr_tc.target_type.base_type)
+                tc.apply_dtype(arr_tc.target_type.base_type, expr)
 
                 index_tc = TypeContext()
                 self.visit_expr(idx, index_tc)
                 if index_tc.target_type is None:
-                    index_tc.apply_dtype(idx, self._ctx.index_dtype)
+                    index_tc.apply_dtype(self._ctx.index_dtype, idx)
                 elif not isinstance(index_tc.target_type, PsIntegerType):
                     raise TypificationError(
                         f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}"
@@ -332,7 +359,7 @@ class Typifier:
                         "Type of argument to a Deref is not dereferencable"
                     )
 
-                tc.apply_dtype(expr, ptr_tc.target_type.base_type)
+                tc.apply_dtype(ptr_tc.target_type.base_type, expr)
 
             case PsAddressOf(arg):
                 arg_tc = TypeContext()
@@ -344,10 +371,11 @@ class Typifier:
                     )
 
                 ptr_type = PsPointerType(arg_tc.target_type, True)
-                tc.apply_dtype(expr, ptr_type)
+                tc.apply_dtype(ptr_type, expr)
 
             case PsLookup(aggr, member_name):
-                aggr_tc = TypeContext(None)
+                #   Members of a struct type inherit the struct type's `const` qualifier
+                aggr_tc = TypeContext(None, require_nonconst=tc.require_nonconst)
                 self.visit_expr(aggr, aggr_tc)
                 aggr_type = aggr_tc.target_type
 
@@ -361,12 +389,12 @@ class Typifier:
                     raise TypificationError(
                         f"Aggregate of type {aggr_type} does not have a member {member}."
                     )
-                
+
                 member_type = member.dtype
                 if aggr_type.const:
                     member_type = constify(member_type)
 
-                tc.apply_dtype(expr, member_type)
+                tc.apply_dtype(member_type, expr)
 
             case PsBinOp(op1, op2):
                 self.visit_expr(op1, tc)
@@ -405,14 +433,14 @@ class Typifier:
                             f"{len(items)} items as {tc.target_type}"
                         )
                     else:
-                        items_tc.apply_dtype(None, tc.target_type.base_type)
+                        items_tc.apply_dtype(tc.target_type.base_type)
                 else:
                     arr_type = PsArrayType(items_tc.target_type, len(items))
-                    tc.apply_dtype(expr, arr_type)
+                    tc.apply_dtype(arr_type, expr)
 
             case PsCast(dtype, arg):
                 self.visit_expr(arg, TypeContext())
-                tc.apply_dtype(expr, dtype)
+                tc.apply_dtype(dtype, expr)
 
             case _:
                 raise NotImplementedError(f"Can't typify {expr}")
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index 6a1a02860..5e9d7cbec 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -45,7 +45,7 @@ def test_typify_simple():
                 assert cs.dtype == constify(ctx.default_dtype)
             case PsSymbolExpr(symb):
                 assert symb.name in "xyz"
-                assert symb.dtype == constify(ctx.default_dtype)
+                assert symb.dtype == ctx.default_dtype
             case PsBinOp(op1, op2):
                 check(op1)
                 check(op2)
@@ -109,7 +109,6 @@ def test_lhs_constness():
     )
 
     x, y, z = sp.symbols("x, y, z")
-    q = TypedSymbol("q", Fp(32, const=True))
 
     #   Assignment RHS may not be const
     asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y)))
@@ -119,9 +118,6 @@ def test_lhs_constness():
     with pytest.raises(TypificationError):
         _ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y)))
 
-    # with pytest.raises(TypificationError):
-    #     _ = typify(freeze(Assignment(q, x - z)))
-
     np_struct = np.dtype([("size", np.uint32), ("data", np.float32)])
     struct_type = constify(create_type(np_struct))
     struct_field = Field.create_generic(
@@ -146,6 +142,10 @@ def test_typify_structs():
     fasm = freeze(asm)
     fasm = typify(fasm)
 
+    asm = Assignment(f.absolute_access((0,), "data"), x)
+    fasm = freeze(asm)
+    fasm = typify(fasm)
+
     #   Bad
     asm = Assignment(x, f.absolute_access((0,), "size"))
     fasm = freeze(asm)
@@ -170,7 +170,7 @@ def test_contextual_typing():
                 assert cs.dtype == constify(ctx.default_dtype)
             case PsSymbolExpr(symb):
                 assert symb.name in "xyz"
-                assert symb.dtype == constify(ctx.default_dtype)
+                assert symb.dtype == ctx.default_dtype
             case PsBinOp(op1, op2):
                 check(op1)
                 check(op2)
-- 
GitLab