Skip to content
Snippets Groups Projects
Commit f907b330 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix minor type conflicts with predefined array types

parent 582784a2
No related branches found
No related tags found
1 merge request!418Nesting of Type Contexts, Type Hints, and Improved Array Typing
Pipeline #69406 passed
...@@ -73,12 +73,14 @@ class TypeHint: ...@@ -73,12 +73,14 @@ class TypeHint:
@dataclass(frozen=True) @dataclass(frozen=True)
class ToDefault(TypeHint): class ToDefault(TypeHint):
"""Indicates to fall back to a default type.""" """Indicates to fall back to a default type."""
default_dtype: PsType default_dtype: PsType
@dataclass(frozen=True) @dataclass(frozen=True)
class DereferencableTo(TypeHint): class DereferencableTo(TypeHint):
"""Indicates that the type has to be dereferencable to the given base type.""" """Indicates that the type has to be dereferencable to the given base type."""
base_type: PsType | TypeHint base_type: PsType | TypeHint
...@@ -111,7 +113,7 @@ class TypeContext: ...@@ -111,7 +113,7 @@ class TypeContext:
- Additional restrictions may be added in the future. - Additional restrictions may be added in the future.
**Target type** **Target type**
Each typing context needs to be assigned its target type at some point. Each typing context needs to be assigned its target type at some point.
The target type may be The target type may be
...@@ -120,8 +122,8 @@ class TypeContext: ...@@ -120,8 +122,8 @@ class TypeContext:
- inferred from the type of the enclosing context via an inference hook, or as a last resort - inferred from the type of the enclosing context via an inference hook, or as a last resort
- determined from a type hint applied to the enclosing context via an inference hook. - determined from a type hint applied to the enclosing context via an inference hook.
**Expansion** **Expansion**
Expression nodes are added to a type context using either `apply_dtype` or `infer_dtype`. Expression nodes are added to a type context using either `apply_dtype` or `infer_dtype`.
In both cases, the context's target type will be applied to the node, In both cases, the context's target type will be applied to the node,
unless it already has a conflicting type. unless it already has a conflicting type.
...@@ -160,7 +162,7 @@ class TypeContext: ...@@ -160,7 +162,7 @@ class TypeContext:
@property @property
def target_type(self) -> PsType | None: def target_type(self) -> PsType | None:
return self._target_type return self._target_type
def get_target_type(self) -> PsType: def get_target_type(self) -> PsType:
assert self._target_type is not None assert self._target_type is not None
return self._target_type return self._target_type
...@@ -209,7 +211,7 @@ class TypeContext: ...@@ -209,7 +211,7 @@ class TypeContext:
def apply_hint(self, hint: TypeHint): def apply_hint(self, hint: TypeHint):
"""Attempt to resolve this type context from the given type hint. """Attempt to resolve this type context from the given type hint.
If the hint is not sufficient to resolve the context, a `TypificationError` is raised. If the hint is not sufficient to resolve the context, a `TypificationError` is raised.
""" """
assert self._target_type is None assert self._target_type is None
...@@ -230,7 +232,9 @@ class TypeContext: ...@@ -230,7 +232,9 @@ class TypeContext:
# Now we have the target type # Now we have the target type
self._propagate_target_type() self._propagate_target_type()
else: else:
raise TypificationError(f"Unable to infer context type from hint {hint}") raise TypificationError(
f"Unable to infer context type from hint {hint}"
)
def infer_dtype(self, expr: PsExpression): def infer_dtype(self, expr: PsExpression):
"""Infer the data type for the given expression. """Infer the data type for the given expression.
...@@ -250,7 +254,7 @@ class TypeContext: ...@@ -250,7 +254,7 @@ class TypeContext:
def _propagate_target_type(self): def _propagate_target_type(self):
"""Propagates the target type to any registered inference hooks and applies it to any deferred nodes. """Propagates the target type to any registered inference hooks and applies it to any deferred nodes.
Call after the target type of this context has been set. Call after the target type of this context has been set.
""" """
assert self._target_type is not None assert self._target_type is not None
...@@ -551,7 +555,10 @@ class Typifier: ...@@ -551,7 +555,10 @@ class Typifier:
self.visit_expr(arr, arr_tc) self.visit_expr(arr, arr_tc)
if arr_tc.target_type is None: if arr_tc.target_type is None:
def subscript_hook(type_or_hint: PsType | TypeHint) -> PsType | None:
def subscript_hook(
type_or_hint: PsType | TypeHint,
) -> PsType | None:
# Whatever type the enclosing context is to be, # Whatever type the enclosing context is to be,
# the type of `arr` has to be dereferencable to it # the type of `arr` has to be dereferencable to it
arr_tc.apply_hint(DereferencableTo(type_or_hint)) arr_tc.apply_hint(DereferencableTo(type_or_hint))
...@@ -560,7 +567,7 @@ class Typifier: ...@@ -560,7 +567,7 @@ class Typifier:
# -> pass its dereferenced version to the outer context # -> pass its dereferenced version to the outer context
assert isinstance(arr_tc.target_type, PsDereferencableType) assert isinstance(arr_tc.target_type, PsDereferencableType)
return arr_tc.target_type.base_type return arr_tc.target_type.base_type
tc.hook(subscript_hook) tc.hook(subscript_hook)
elif not isinstance(arr_tc.target_type, PsDereferencableType): elif not isinstance(arr_tc.target_type, PsDereferencableType):
...@@ -667,7 +674,7 @@ class Typifier: ...@@ -667,7 +674,7 @@ class Typifier:
if args_tc.target_type is None: if args_tc.target_type is None:
args_tc.apply_hint(ToDefault(self._ctx.default_dtype)) args_tc.apply_hint(ToDefault(self._ctx.default_dtype))
if not isinstance(args_tc.target_type, PsNumericType): if not isinstance(args_tc.target_type, PsNumericType):
raise TypificationError( raise TypificationError(
f"Invalid type in arguments to relation\n" f"Invalid type in arguments to relation\n"
...@@ -707,6 +714,23 @@ class Typifier: ...@@ -707,6 +714,23 @@ class Typifier:
case PsArrayInitList(items): case PsArrayInitList(items):
items_tc = TypeContext() items_tc = TypeContext()
def propagate_elem_type(elem_type: PsType, length: int | None):
if length is not None and length != len(items):
raise TypificationError(
"Array size mismatch: Cannot typify initializer list with "
f"{len(items)} items as {tc.target_type}"
)
items_tc.apply_dtype(deconstify(elem_type))
# If the enclosing context already prescribes an array type,
# eagerly propagate it to the items-context
if isinstance(tc.target_type, PsArrayType):
inherit_arr_type = True
propagate_elem_type(tc.target_type.base_type, tc.target_type.length)
else:
inherit_arr_type = False
for item in items: for item in items:
self.visit_expr(item, items_tc) self.visit_expr(item, items_tc)
...@@ -715,12 +739,7 @@ class Typifier: ...@@ -715,12 +739,7 @@ class Typifier:
def hook(type_or_hint: PsType | TypeHint) -> PsType | None: def hook(type_or_hint: PsType | TypeHint) -> PsType | None:
match type_or_hint: match type_or_hint:
case PsArrayType(elem_type, length): case PsArrayType(elem_type, length):
if length is not None and length != len(items): propagate_elem_type(elem_type, length)
raise TypificationError(
"Array size mismatch: Cannot typify initializer list with "
f"{len(items)} items as {tc.target_type}"
)
items_tc.apply_dtype(deconstify(elem_type))
tc.infer_dtype(expr) tc.infer_dtype(expr)
return None return None
...@@ -731,25 +750,31 @@ class Typifier: ...@@ -731,25 +750,31 @@ class Typifier:
items_tc.apply_hint(elem_type_or_hint) items_tc.apply_hint(elem_type_or_hint)
tc.infer_dtype(expr) tc.infer_dtype(expr)
return PsArrayType(deconstify(items_tc.get_target_type()), len(items)) return PsArrayType(
deconstify(items_tc.get_target_type()), len(items)
)
case ToDefault(): case ToDefault():
items_tc.apply_hint(type_or_hint) items_tc.apply_hint(type_or_hint)
tc.infer_dtype(expr) tc.infer_dtype(expr)
return PsArrayType(deconstify(items_tc.get_target_type()), len(items)) return PsArrayType(
deconstify(items_tc.get_target_type()), len(items)
)
case TypeHint(): case TypeHint():
# Can't deal with any other type hints # Can't deal with any other type hints
return None return None
case other_type: case other_type:
raise TypificationError( raise TypificationError(
f"Cannot apply type {other_type} to array initializer {expr}." f"Cannot apply type {other_type} to array initializer {expr}."
) )
tc.hook(hook) tc.hook(hook)
elif inherit_arr_type:
tc.infer_dtype(expr)
else: else:
arr_type = PsArrayType(items_tc.target_type, len(items)) arr_type = PsArrayType(deconstify(items_tc.target_type), len(items))
tc.apply_dtype(arr_type, expr) tc.apply_dtype(arr_type, expr)
case PsCast(dtype, arg): case PsCast(dtype, arg):
......
...@@ -284,6 +284,35 @@ def test_constant_array_decls(): ...@@ -284,6 +284,35 @@ def test_constant_array_decls():
assert ctx.get_symbol("y").dtype == Arr(Arr(Fp(16), 4), 2) assert ctx.get_symbol("y").dtype == Arr(Arr(Fp(16), 4), 2)
def test_array_decl_lhs_type_propagation():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
# Type of array initializer is figured out to be half [4],
# but LHS symbol has type `half []` without shape information
# Expected behavior: LHS type overrides inferred type
arr = TypedSymbol("arr", Arr(Fp(16)))
decl = freeze(Assignment(arr, (5, 78, 1, TypedSymbol("x", Fp(16)))))
decl = typify(decl)
assert decl.rhs.dtype == constify(arr.dtype)
def test_array_decl_constness_conflict():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
# Type of array initializer is figured out to be half [4],
# but LHS symbol has fixed type `const half []`.
# This is still a valid declaration.
arr = TypedSymbol("arr", Arr(Fp(16, const=True)))
decl = freeze(Assignment(arr, (5, 78, 1, TypedSymbol("x", Fp(16)))))
decl = typify(decl)
assert decl.rhs.dtype == constify(Arr(Fp(16, const=True)))
def test_inline_arrays_1d(): def test_inline_arrays_1d():
ctx = KernelCreationContext(default_dtype=Fp(16)) ctx = KernelCreationContext(default_dtype=Fp(16))
freeze = FreezeExpressions(ctx) freeze = FreezeExpressions(ctx)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment