diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 1fede168a28809310d9224f1f093dbb712ce7e6b..432f68e0e7843e9cdd3d02aea750f72048719ffb 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -73,12 +73,14 @@ class TypeHint: @dataclass(frozen=True) class ToDefault(TypeHint): """Indicates to fall back to a default type.""" + default_dtype: PsType @dataclass(frozen=True) class DereferencableTo(TypeHint): """Indicates that the type has to be dereferencable to the given base type.""" + base_type: PsType | TypeHint @@ -111,7 +113,7 @@ class TypeContext: - Additional restrictions may be added in the future. **Target type** - + Each typing context needs to be assigned its target type at some point. The target type may be @@ -120,8 +122,8 @@ class TypeContext: - inferred from the type of the enclosing context via an inference hook, or as a last resort - determined from a type hint applied to the enclosing context via an inference hook. - **Expansion** - + **Expansion** + Expression nodes are added to a type context using either `apply_dtype` or `infer_dtype`. In both cases, the context's target type will be applied to the node, unless it already has a conflicting type. @@ -160,7 +162,7 @@ class TypeContext: @property def target_type(self) -> PsType | None: return self._target_type - + def get_target_type(self) -> PsType: assert self._target_type is not None return self._target_type @@ -209,7 +211,7 @@ class TypeContext: def apply_hint(self, hint: TypeHint): """Attempt to resolve this type context from the given type hint. - + If the hint is not sufficient to resolve the context, a `TypificationError` is raised. """ assert self._target_type is None @@ -230,7 +232,9 @@ class TypeContext: # Now we have the target type self._propagate_target_type() else: - raise TypificationError(f"Unable to infer context type from hint {hint}") + raise TypificationError( + f"Unable to infer context type from hint {hint}" + ) def infer_dtype(self, expr: PsExpression): """Infer the data type for the given expression. @@ -250,7 +254,7 @@ class TypeContext: def _propagate_target_type(self): """Propagates the target type to any registered inference hooks and applies it to any deferred nodes. - + Call after the target type of this context has been set. """ assert self._target_type is not None @@ -551,7 +555,10 @@ class Typifier: self.visit_expr(arr, arr_tc) if arr_tc.target_type is None: - def subscript_hook(type_or_hint: PsType | TypeHint) -> PsType | None: + + def subscript_hook( + type_or_hint: PsType | TypeHint, + ) -> PsType | None: # Whatever type the enclosing context is to be, # the type of `arr` has to be dereferencable to it arr_tc.apply_hint(DereferencableTo(type_or_hint)) @@ -560,7 +567,7 @@ class Typifier: # -> pass its dereferenced version to the outer context assert isinstance(arr_tc.target_type, PsDereferencableType) return arr_tc.target_type.base_type - + tc.hook(subscript_hook) elif not isinstance(arr_tc.target_type, PsDereferencableType): @@ -667,7 +674,7 @@ class Typifier: if args_tc.target_type is None: args_tc.apply_hint(ToDefault(self._ctx.default_dtype)) - + if not isinstance(args_tc.target_type, PsNumericType): raise TypificationError( f"Invalid type in arguments to relation\n" @@ -707,6 +714,23 @@ class Typifier: case PsArrayInitList(items): items_tc = TypeContext() + + def propagate_elem_type(elem_type: PsType, length: int | None): + if length is not None and length != len(items): + raise TypificationError( + "Array size mismatch: Cannot typify initializer list with " + f"{len(items)} items as {tc.target_type}" + ) + items_tc.apply_dtype(deconstify(elem_type)) + + # If the enclosing context already prescribes an array type, + # eagerly propagate it to the items-context + if isinstance(tc.target_type, PsArrayType): + inherit_arr_type = True + propagate_elem_type(tc.target_type.base_type, tc.target_type.length) + else: + inherit_arr_type = False + for item in items: self.visit_expr(item, items_tc) @@ -715,12 +739,7 @@ class Typifier: def hook(type_or_hint: PsType | TypeHint) -> PsType | None: match type_or_hint: case PsArrayType(elem_type, length): - if length is not None and length != len(items): - raise TypificationError( - "Array size mismatch: Cannot typify initializer list with " - f"{len(items)} items as {tc.target_type}" - ) - items_tc.apply_dtype(deconstify(elem_type)) + propagate_elem_type(elem_type, length) tc.infer_dtype(expr) return None @@ -731,25 +750,31 @@ class Typifier: items_tc.apply_hint(elem_type_or_hint) tc.infer_dtype(expr) - return PsArrayType(deconstify(items_tc.get_target_type()), len(items)) + return PsArrayType( + deconstify(items_tc.get_target_type()), len(items) + ) case ToDefault(): items_tc.apply_hint(type_or_hint) tc.infer_dtype(expr) - return PsArrayType(deconstify(items_tc.get_target_type()), len(items)) - + return PsArrayType( + deconstify(items_tc.get_target_type()), len(items) + ) + case TypeHint(): # Can't deal with any other type hints return None - + case other_type: raise TypificationError( f"Cannot apply type {other_type} to array initializer {expr}." ) tc.hook(hook) + elif inherit_arr_type: + tc.infer_dtype(expr) else: - arr_type = PsArrayType(items_tc.target_type, len(items)) + arr_type = PsArrayType(deconstify(items_tc.target_type), len(items)) tc.apply_dtype(arr_type, expr) case PsCast(dtype, arg): diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 0c707e33125d05cf5c1261ec68e899090f182bbb..8a25c16f4b2f5beb86a056fecf0d74f35438c562 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -284,6 +284,35 @@ def test_constant_array_decls(): assert ctx.get_symbol("y").dtype == Arr(Arr(Fp(16), 4), 2) +def test_array_decl_lhs_type_propagation(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + # Type of array initializer is figured out to be half [4], + # but LHS symbol has type `half []` without shape information + # Expected behavior: LHS type overrides inferred type + arr = TypedSymbol("arr", Arr(Fp(16))) + decl = freeze(Assignment(arr, (5, 78, 1, TypedSymbol("x", Fp(16))))) + decl = typify(decl) + assert decl.rhs.dtype == constify(arr.dtype) + + +def test_array_decl_constness_conflict(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + # Type of array initializer is figured out to be half [4], + # but LHS symbol has fixed type `const half []`. + # This is still a valid declaration. + + arr = TypedSymbol("arr", Arr(Fp(16, const=True))) + decl = freeze(Assignment(arr, (5, 78, 1, TypedSymbol("x", Fp(16))))) + decl = typify(decl) + assert decl.rhs.dtype == constify(Arr(Fp(16, const=True))) + + def test_inline_arrays_1d(): ctx = KernelCreationContext(default_dtype=Fp(16)) freeze = FreezeExpressions(ctx)