Skip to content
Snippets Groups Projects
Commit c83b06a5 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Introduce DereferencableTo type hint and implement inference hooks for (nested) arrays

parent fbcf566f
No related branches found
No related tags found
1 merge request!418Nesting of Type Contexts, Type Hints, and Improved Array Typing
......@@ -73,9 +73,16 @@ class TypeHint:
@dataclass(frozen=True)
class ToDefault(TypeHint):
"""Indicates to fall back to a default type."""
default_dtype: PsType
@dataclass(frozen=True)
class DereferencableTo(TypeHint):
"""Indicates that the type has to be dereferencable to the given base type."""
base_type: PsType | TypeHint
InferenceHook = Callable[[PsType | TypeHint], PsType | None]
"""An inference hook is a callback that is attached to a type context,
to be called once type information about that context is known.
......@@ -83,8 +90,8 @@ The inference hook will then try to use that information to resolve nested type
and potentially the context it is attached to as well.
When called with a `PsType`, that type is the target type of the context to which the hook is attached.
The hook has to use this type to resolve any nested type contexts and return `None`;
if it cannot resolve its nested contexts, it must raise a TypificationError.
The hook has to use this type to resolve any nested type contexts and must return either `None` or the same data type.
If it cannot resolve its nested contexts, it must raise a TypificationError.
When called with a `TypeHint`, the inference hook has to attempt to resolve its nested contexts.
If it succeeds, it has to return the data type that must be applied to the outer context.
......@@ -200,8 +207,9 @@ class TypeContext:
self.apply_dtype(default_dtype)
case _:
for i, hook in enumerate(self._inference_hooks):
self._target_type = hook(hint)
if self._target_type is not None:
target_type = hook(hint)
if target_type is not None:
self._target_type = self._fix_constness(target_type)
# That hook was successful; remove it so it is not called a second time
del self._inference_hooks[i]
......@@ -523,12 +531,25 @@ class Typifier:
arr_tc = TypeContext()
self.visit_expr(arr, arr_tc)
if not isinstance(arr_tc.target_type, PsDereferencableType):
if arr_tc.target_type is None:
def subscript_hook(type_or_hint: PsType | TypeHint) -> PsType | None:
# Whatever type the enclosing context is to be,
# the type of `arr` has to be dereferencable to it
arr_tc.apply_hint(DereferencableTo(type_or_hint))
# Now we know the type of the array
# -> pass its dereferenced version to the outer context
assert isinstance(arr_tc.target_type, PsDereferencableType)
return arr_tc.target_type.base_type
tc.hook(subscript_hook)
elif not isinstance(arr_tc.target_type, PsDereferencableType):
raise TypificationError(
"Type of subscript base is not subscriptable."
)
tc.apply_dtype(arr_tc.target_type.base_type, expr)
else:
tc.apply_dtype(arr_tc.target_type.base_type, expr)
index_tc = TypeContext()
self.visit_expr(idx, index_tc)
......@@ -683,13 +704,25 @@ class Typifier:
)
items_tc.apply_dtype(deconstify(elem_type))
tc.infer_dtype(expr)
case DereferencableTo(elem_type_or_hint):
if isinstance(elem_type_or_hint, PsType):
items_tc.apply_dtype(deconstify(elem_type_or_hint))
else:
items_tc.apply_hint(elem_type_or_hint)
tc.infer_dtype(expr)
return PsArrayType(deconstify(items_tc.get_target_type()), len(items))
case ToDefault():
items_tc.apply_hint(type_or_hint)
tc.infer_dtype(expr)
return PsArrayType(deconstify(items_tc.get_target_type()), len(items))
case TypeHint():
# Can't deal with any other type hints
return None
case other_type:
raise TypificationError(
f"Cannot apply type {other_type} to array initializer {expr}."
......
......@@ -29,6 +29,7 @@ from pystencils.backend.ast.expressions import (
PsLt,
PsCall,
PsTernary,
PsArrayInitList,
)
from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import CFunction
......@@ -220,7 +221,7 @@ def test_constant_decls():
typify = Typifier(ctx)
x, y = sp.symbols("x, y")
decl = freeze(Assignment(x, 3.0))
decl = typify(decl)
assert ctx.get_symbol("x").dtype == Fp(16)
......@@ -250,6 +251,59 @@ def test_constant_array_decls():
assert ctx.get_symbol("y").dtype == Arr(Arr(Fp(16), 4), 2)
def test_inline_arrays_1d():
ctx = KernelCreationContext(default_dtype=Fp(16))
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y = sp.symbols("x, y")
idx = TypedSymbol("idx", Int(32))
arr: PsArrayInitList = cast(PsArrayInitList, freeze(sp.Tuple(1, 2, 3, 4)))
decl = PsDeclaration(freeze(x), freeze(y) + PsSubscript(arr, freeze(idx)))
# The array elements should learn their type from the context, which gets it from `y`
decl = typify(decl)
assert decl.lhs.dtype == Fp(16, const=True)
assert decl.rhs.dtype == Fp(16, const=True)
assert arr.dtype == Arr(Fp(16), 4, const=True)
for item in arr.items:
assert item.dtype == Fp(16, const=True)
def test_inline_arrays_3d():
ctx = KernelCreationContext(default_dtype=Fp(16))
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y = sp.symbols("x, y")
idx = [TypedSymbol(f"idx_{i}", Int(32)) for i in range(3)]
arr = freeze(sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10))))
decl = PsDeclaration(
freeze(x),
freeze(y)
+ PsSubscript(
PsSubscript(PsSubscript(arr, freeze(idx[0])), freeze(idx[1])),
freeze(idx[2]),
),
)
# The array elements should learn their type from the context, which gets it from `y`
decl = typify(decl)
assert decl.lhs.dtype == Fp(16, const=True)
assert decl.rhs.dtype == Fp(16, const=True)
assert arr.dtype == Arr(Arr(Arr(Fp(16), 2), 3), 2, const=True)
for item in arr.items:
assert item.dtype == Arr(Arr(Fp(16), 2), 3, const=True)
for iitem in item.items:
assert iitem.dtype == Arr(Fp(16), 2, const=True)
for iiitem in iitem.items:
assert iiitem.dtype == Fp(16, const=True)
def test_lhs_inference():
ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
freeze = FreezeExpressions(ctx)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment