From c83b06a597376daacffedd6cb78fc96e0709f877 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 26 Jul 2024 14:01:13 +0200
Subject: [PATCH] Introduce DereferencableTo type hint and implement inference
 hooks for (nested) arrays

---
 .../backend/kernelcreation/typification.py    | 47 +++++++++++++---
 .../kernelcreation/test_typification.py       | 56 ++++++++++++++++++-
 2 files changed, 95 insertions(+), 8 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index dd004f5f0..2e81a0779 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -73,9 +73,16 @@ class TypeHint:
 
 @dataclass(frozen=True)
 class ToDefault(TypeHint):
+    """Indicates to fall back to a default type."""
     default_dtype: PsType
 
 
+@dataclass(frozen=True)
+class DereferencableTo(TypeHint):
+    """Indicates that the type has to be dereferencable to the given base type."""
+    base_type: PsType | TypeHint
+
+
 InferenceHook = Callable[[PsType | TypeHint], PsType | None]
 """An inference hook is a callback that is attached to a type context,
 to be called once type information about that context is known.
@@ -83,8 +90,8 @@ The inference hook will then try to use that information to resolve nested type
 and potentially the context it is attached to as well.
 
 When called with a `PsType`, that type is the target type of the context to which the hook is attached.
-The hook has to use this type to resolve any nested type contexts and return `None`;
-if it cannot resolve its nested contexts, it must raise a TypificationError.
+The hook has to use this type to resolve any nested type contexts and must return either `None` or the same data type.
+If it cannot resolve its nested contexts, it must raise a TypificationError.
 
 When called with a `TypeHint`, the inference hook has to attempt to resolve its nested contexts.
 If it succeeds, it has to return the data type that must be applied to the outer context.
@@ -200,8 +207,9 @@ class TypeContext:
                 self.apply_dtype(default_dtype)
             case _:
                 for i, hook in enumerate(self._inference_hooks):
-                    self._target_type = hook(hint)
-                    if self._target_type is not None:
+                    target_type = hook(hint)
+                    if target_type is not None:
+                        self._target_type = self._fix_constness(target_type)
                         #   That hook was successful; remove it so it is not called a second time
                         del self._inference_hooks[i]
 
@@ -523,12 +531,25 @@ class Typifier:
                 arr_tc = TypeContext()
                 self.visit_expr(arr, arr_tc)
 
-                if not isinstance(arr_tc.target_type, PsDereferencableType):
+                if arr_tc.target_type is None:
+                    def subscript_hook(type_or_hint: PsType | TypeHint) -> PsType | None:
+                        #   Whatever type the enclosing context is to be,
+                        #   the type of `arr` has to be dereferencable to it
+                        arr_tc.apply_hint(DereferencableTo(type_or_hint))
+
+                        #   Now we know the type of the array
+                        #   -> pass its dereferenced version to the outer context
+                        assert isinstance(arr_tc.target_type, PsDereferencableType)
+                        return arr_tc.target_type.base_type
+                    
+                    tc.hook(subscript_hook)
+
+                elif not isinstance(arr_tc.target_type, PsDereferencableType):
                     raise TypificationError(
                         "Type of subscript base is not subscriptable."
                     )
-
-                tc.apply_dtype(arr_tc.target_type.base_type, expr)
+                else:
+                    tc.apply_dtype(arr_tc.target_type.base_type, expr)
 
                 index_tc = TypeContext()
                 self.visit_expr(idx, index_tc)
@@ -683,13 +704,25 @@ class Typifier:
                                     )
                                 items_tc.apply_dtype(deconstify(elem_type))
                                 tc.infer_dtype(expr)
+
+                            case DereferencableTo(elem_type_or_hint):
+                                if isinstance(elem_type_or_hint, PsType):
+                                    items_tc.apply_dtype(deconstify(elem_type_or_hint))
+                                else:
+                                    items_tc.apply_hint(elem_type_or_hint)
+                                tc.infer_dtype(expr)
+
+                                return PsArrayType(deconstify(items_tc.get_target_type()), len(items))
+
                             case ToDefault():
                                 items_tc.apply_hint(type_or_hint)
                                 tc.infer_dtype(expr)
                                 return PsArrayType(deconstify(items_tc.get_target_type()), len(items))
+    
                             case TypeHint():
                                 #   Can't deal with any other type hints
                                 return None
+                            
                             case other_type:
                                 raise TypificationError(
                                     f"Cannot apply type {other_type} to array initializer {expr}."
diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py
index 44e3b8fa6..6592f22be 100644
--- a/tests/nbackend/kernelcreation/test_typification.py
+++ b/tests/nbackend/kernelcreation/test_typification.py
@@ -29,6 +29,7 @@ from pystencils.backend.ast.expressions import (
     PsLt,
     PsCall,
     PsTernary,
+    PsArrayInitList,
 )
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.functions import CFunction
@@ -220,7 +221,7 @@ def test_constant_decls():
     typify = Typifier(ctx)
 
     x, y = sp.symbols("x, y")
-    
+
     decl = freeze(Assignment(x, 3.0))
     decl = typify(decl)
     assert ctx.get_symbol("x").dtype == Fp(16)
@@ -250,6 +251,59 @@ def test_constant_array_decls():
     assert ctx.get_symbol("y").dtype == Arr(Arr(Fp(16), 4), 2)
 
 
+def test_inline_arrays_1d():
+    ctx = KernelCreationContext(default_dtype=Fp(16))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    x, y = sp.symbols("x, y")
+    idx = TypedSymbol("idx", Int(32))
+
+    arr: PsArrayInitList = cast(PsArrayInitList, freeze(sp.Tuple(1, 2, 3, 4)))
+    decl = PsDeclaration(freeze(x), freeze(y) + PsSubscript(arr, freeze(idx)))
+    #   The array elements should learn their type from the context, which gets it from `y`
+
+    decl = typify(decl)
+    assert decl.lhs.dtype == Fp(16, const=True)
+    assert decl.rhs.dtype == Fp(16, const=True)
+
+    assert arr.dtype == Arr(Fp(16), 4, const=True)
+    for item in arr.items:
+        assert item.dtype == Fp(16, const=True)
+
+
+def test_inline_arrays_3d():
+    ctx = KernelCreationContext(default_dtype=Fp(16))
+    freeze = FreezeExpressions(ctx)
+    typify = Typifier(ctx)
+
+    x, y = sp.symbols("x, y")
+    idx = [TypedSymbol(f"idx_{i}", Int(32)) for i in range(3)]
+
+    arr = freeze(sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10))))
+    decl = PsDeclaration(
+        freeze(x),
+        freeze(y)
+        + PsSubscript(
+            PsSubscript(PsSubscript(arr, freeze(idx[0])), freeze(idx[1])),
+            freeze(idx[2]),
+        ),
+    )
+    #   The array elements should learn their type from the context, which gets it from `y`
+
+    decl = typify(decl)
+    assert decl.lhs.dtype == Fp(16, const=True)
+    assert decl.rhs.dtype == Fp(16, const=True)
+
+    assert arr.dtype == Arr(Arr(Arr(Fp(16), 2), 3), 2, const=True)
+    for item in arr.items:
+        assert item.dtype == Arr(Arr(Fp(16), 2), 3, const=True)
+        for iitem in item.items:
+            assert iitem.dtype == Arr(Fp(16), 2, const=True)
+            for iiitem in iitem.items:
+                assert iiitem.dtype == Fp(16, const=True)
+
+
 def test_lhs_inference():
     ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
     freeze = FreezeExpressions(ctx)
-- 
GitLab