From c83b06a597376daacffedd6cb78fc96e0709f877 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 26 Jul 2024 14:01:13 +0200 Subject: [PATCH] Introduce DereferencableTo type hint and implement inference hooks for (nested) arrays --- .../backend/kernelcreation/typification.py | 47 +++++++++++++--- .../kernelcreation/test_typification.py | 56 ++++++++++++++++++- 2 files changed, 95 insertions(+), 8 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index dd004f5f0..2e81a0779 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -73,9 +73,16 @@ class TypeHint: @dataclass(frozen=True) class ToDefault(TypeHint): + """Indicates to fall back to a default type.""" default_dtype: PsType +@dataclass(frozen=True) +class DereferencableTo(TypeHint): + """Indicates that the type has to be dereferencable to the given base type.""" + base_type: PsType | TypeHint + + InferenceHook = Callable[[PsType | TypeHint], PsType | None] """An inference hook is a callback that is attached to a type context, to be called once type information about that context is known. @@ -83,8 +90,8 @@ The inference hook will then try to use that information to resolve nested type and potentially the context it is attached to as well. When called with a `PsType`, that type is the target type of the context to which the hook is attached. -The hook has to use this type to resolve any nested type contexts and return `None`; -if it cannot resolve its nested contexts, it must raise a TypificationError. +The hook has to use this type to resolve any nested type contexts and must return either `None` or the same data type. +If it cannot resolve its nested contexts, it must raise a TypificationError. When called with a `TypeHint`, the inference hook has to attempt to resolve its nested contexts. If it succeeds, it has to return the data type that must be applied to the outer context. @@ -200,8 +207,9 @@ class TypeContext: self.apply_dtype(default_dtype) case _: for i, hook in enumerate(self._inference_hooks): - self._target_type = hook(hint) - if self._target_type is not None: + target_type = hook(hint) + if target_type is not None: + self._target_type = self._fix_constness(target_type) # That hook was successful; remove it so it is not called a second time del self._inference_hooks[i] @@ -523,12 +531,25 @@ class Typifier: arr_tc = TypeContext() self.visit_expr(arr, arr_tc) - if not isinstance(arr_tc.target_type, PsDereferencableType): + if arr_tc.target_type is None: + def subscript_hook(type_or_hint: PsType | TypeHint) -> PsType | None: + # Whatever type the enclosing context is to be, + # the type of `arr` has to be dereferencable to it + arr_tc.apply_hint(DereferencableTo(type_or_hint)) + + # Now we know the type of the array + # -> pass its dereferenced version to the outer context + assert isinstance(arr_tc.target_type, PsDereferencableType) + return arr_tc.target_type.base_type + + tc.hook(subscript_hook) + + elif not isinstance(arr_tc.target_type, PsDereferencableType): raise TypificationError( "Type of subscript base is not subscriptable." ) - - tc.apply_dtype(arr_tc.target_type.base_type, expr) + else: + tc.apply_dtype(arr_tc.target_type.base_type, expr) index_tc = TypeContext() self.visit_expr(idx, index_tc) @@ -683,13 +704,25 @@ class Typifier: ) items_tc.apply_dtype(deconstify(elem_type)) tc.infer_dtype(expr) + + case DereferencableTo(elem_type_or_hint): + if isinstance(elem_type_or_hint, PsType): + items_tc.apply_dtype(deconstify(elem_type_or_hint)) + else: + items_tc.apply_hint(elem_type_or_hint) + tc.infer_dtype(expr) + + return PsArrayType(deconstify(items_tc.get_target_type()), len(items)) + case ToDefault(): items_tc.apply_hint(type_or_hint) tc.infer_dtype(expr) return PsArrayType(deconstify(items_tc.get_target_type()), len(items)) + case TypeHint(): # Can't deal with any other type hints return None + case other_type: raise TypificationError( f"Cannot apply type {other_type} to array initializer {expr}." diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 44e3b8fa6..6592f22be 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -29,6 +29,7 @@ from pystencils.backend.ast.expressions import ( PsLt, PsCall, PsTernary, + PsArrayInitList, ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import CFunction @@ -220,7 +221,7 @@ def test_constant_decls(): typify = Typifier(ctx) x, y = sp.symbols("x, y") - + decl = freeze(Assignment(x, 3.0)) decl = typify(decl) assert ctx.get_symbol("x").dtype == Fp(16) @@ -250,6 +251,59 @@ def test_constant_array_decls(): assert ctx.get_symbol("y").dtype == Arr(Arr(Fp(16), 4), 2) +def test_inline_arrays_1d(): + ctx = KernelCreationContext(default_dtype=Fp(16)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x, y = sp.symbols("x, y") + idx = TypedSymbol("idx", Int(32)) + + arr: PsArrayInitList = cast(PsArrayInitList, freeze(sp.Tuple(1, 2, 3, 4))) + decl = PsDeclaration(freeze(x), freeze(y) + PsSubscript(arr, freeze(idx))) + # The array elements should learn their type from the context, which gets it from `y` + + decl = typify(decl) + assert decl.lhs.dtype == Fp(16, const=True) + assert decl.rhs.dtype == Fp(16, const=True) + + assert arr.dtype == Arr(Fp(16), 4, const=True) + for item in arr.items: + assert item.dtype == Fp(16, const=True) + + +def test_inline_arrays_3d(): + ctx = KernelCreationContext(default_dtype=Fp(16)) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + x, y = sp.symbols("x, y") + idx = [TypedSymbol(f"idx_{i}", Int(32)) for i in range(3)] + + arr = freeze(sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10)))) + decl = PsDeclaration( + freeze(x), + freeze(y) + + PsSubscript( + PsSubscript(PsSubscript(arr, freeze(idx[0])), freeze(idx[1])), + freeze(idx[2]), + ), + ) + # The array elements should learn their type from the context, which gets it from `y` + + decl = typify(decl) + assert decl.lhs.dtype == Fp(16, const=True) + assert decl.rhs.dtype == Fp(16, const=True) + + assert arr.dtype == Arr(Arr(Arr(Fp(16), 2), 3), 2, const=True) + for item in arr.items: + assert item.dtype == Arr(Arr(Fp(16), 2), 3, const=True) + for iitem in item.items: + assert iitem.dtype == Arr(Fp(16), 2, const=True) + for iiitem in iitem.items: + assert iiitem.dtype == Fp(16, const=True) + + def test_lhs_inference(): ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) freeze = FreezeExpressions(ctx) -- GitLab