From c0f2fb544fe086d782793508db392e1cfc0b5677 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 2 Apr 2024 10:31:05 +0200
Subject: [PATCH] Add one more test case and doc comments

---
 .../backend/kernelcreation/typification.py    | 27 ++++++++++++++++---
 .../kernelcreation/test_typification.py       | 16 ++++++++---
 2 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index b01b1d35e..6cb42fce0 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -55,6 +55,17 @@ NodeT = TypeVar("NodeT", bound=PsAstNode)
 
 
 class TypeContext:
+    """Typing context, with support for type inference and checking.
+    
+    Instances of this class are used to propagate and check data types across expression subtrees
+    of the AST. Each type context has:
+
+     - A target type `target_type`, which shall be applied to all expressions it covers
+     - A set of restrictions on the target type:
+       - `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides
+       - Additional restrictions may be added in the future.
+    """
+
     def __init__(
         self, target_type: PsType | None = None, require_nonconst: bool = False
     ):
@@ -222,11 +233,11 @@ class Typifier:
 
     **Typing Rules**
 
-    The following rules apply:
+    The following general rules apply:
 
      - The context's `default_dtype` is applied to all untyped symbols
-     - By default, all expressions receive a ``const`` type unless otherwise required
-     - The left-hand side of any non-declaration assignment must not be ``const``
+     - By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's
+       left-hand side
 
     **Typing of symbol expressions**
 
@@ -304,7 +315,15 @@ class Typifier:
                 raise NotImplementedError(f"Can't typify {node}")
 
     def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None:
-        """Recursive processing of expression nodes"""
+        """Recursive processing of expression nodes.
+        
+        This method opens, expands, and closes typing contexts according to the respective expression's
+        typing rules. It may add or check restrictions only when opening or closing a type context.
+
+        The actual type inference and checking during context expansion are performed by the methods
+        of `TypeContext`. ``visit_expr`` tells the typing context how to handle an expression by calling
+        either ``apply_dtype`` or ``infer_dtype``.
+        """
         match expr:
             case PsSymbolExpr(_):
                 if expr.symbol.dtype is None:
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index 5e9d7cbec..d9cc5f9ce 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -6,7 +6,7 @@ from typing import cast
 
 from pystencils import Assignment, TypedSymbol, Field, FieldType
 
-from pystencils.backend.ast.structural import PsDeclaration
+from pystencils.backend.ast.structural import PsDeclaration, PsAssignment, PsExpression
 from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp
 from pystencils.backend.constants import PsConstant
 from pystencils.types import constify
@@ -127,6 +127,17 @@ def test_lhs_constness():
     with pytest.raises(TypificationError):
         _ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x)))
 
+    #   Const LHS is only OK in declarations
+
+    q = ctx.get_symbol("q", Fp(32, const=True))
+    ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q))
+    ast = typify(ast)
+    assert ast.lhs.dtype == Fp(32, const=True)
+
+    ast = PsAssignment(PsExpression.make(q), PsExpression.make(q))
+    with pytest.raises(TypificationError):
+        typify(ast)
+
 
 def test_typify_structs():
     ctx = KernelCreationContext(default_dtype=Fp(32))
@@ -278,6 +289,3 @@ def test_typify_constant_clones():
 
     assert expr_clone.operand1.dtype is None
     assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None
-
-
-# test_lhs_constness()
-- 
GitLab