Skip to content
Snippets Groups Projects
Commit c83b06a5 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Introduce DereferencableTo type hint and implement inference hooks for (nested) arrays

parent fbcf566f
No related branches found
No related tags found
1 merge request!418Nesting of Type Contexts, Type Hints, and Improved Array Typing
...@@ -73,9 +73,16 @@ class TypeHint: ...@@ -73,9 +73,16 @@ class TypeHint:
@dataclass(frozen=True) @dataclass(frozen=True)
class ToDefault(TypeHint): class ToDefault(TypeHint):
"""Indicates to fall back to a default type."""
default_dtype: PsType default_dtype: PsType
@dataclass(frozen=True)
class DereferencableTo(TypeHint):
"""Indicates that the type has to be dereferencable to the given base type."""
base_type: PsType | TypeHint
InferenceHook = Callable[[PsType | TypeHint], PsType | None] InferenceHook = Callable[[PsType | TypeHint], PsType | None]
"""An inference hook is a callback that is attached to a type context, """An inference hook is a callback that is attached to a type context,
to be called once type information about that context is known. to be called once type information about that context is known.
...@@ -83,8 +90,8 @@ The inference hook will then try to use that information to resolve nested type ...@@ -83,8 +90,8 @@ The inference hook will then try to use that information to resolve nested type
and potentially the context it is attached to as well. and potentially the context it is attached to as well.
When called with a `PsType`, that type is the target type of the context to which the hook is attached. When called with a `PsType`, that type is the target type of the context to which the hook is attached.
The hook has to use this type to resolve any nested type contexts and return `None`; The hook has to use this type to resolve any nested type contexts and must return either `None` or the same data type.
if it cannot resolve its nested contexts, it must raise a TypificationError. If it cannot resolve its nested contexts, it must raise a TypificationError.
When called with a `TypeHint`, the inference hook has to attempt to resolve its nested contexts. When called with a `TypeHint`, the inference hook has to attempt to resolve its nested contexts.
If it succeeds, it has to return the data type that must be applied to the outer context. If it succeeds, it has to return the data type that must be applied to the outer context.
...@@ -200,8 +207,9 @@ class TypeContext: ...@@ -200,8 +207,9 @@ class TypeContext:
self.apply_dtype(default_dtype) self.apply_dtype(default_dtype)
case _: case _:
for i, hook in enumerate(self._inference_hooks): for i, hook in enumerate(self._inference_hooks):
self._target_type = hook(hint) target_type = hook(hint)
if self._target_type is not None: if target_type is not None:
self._target_type = self._fix_constness(target_type)
# That hook was successful; remove it so it is not called a second time # That hook was successful; remove it so it is not called a second time
del self._inference_hooks[i] del self._inference_hooks[i]
...@@ -523,12 +531,25 @@ class Typifier: ...@@ -523,12 +531,25 @@ class Typifier:
arr_tc = TypeContext() arr_tc = TypeContext()
self.visit_expr(arr, arr_tc) self.visit_expr(arr, arr_tc)
if not isinstance(arr_tc.target_type, PsDereferencableType): if arr_tc.target_type is None:
def subscript_hook(type_or_hint: PsType | TypeHint) -> PsType | None:
# Whatever type the enclosing context is to be,
# the type of `arr` has to be dereferencable to it
arr_tc.apply_hint(DereferencableTo(type_or_hint))
# Now we know the type of the array
# -> pass its dereferenced version to the outer context
assert isinstance(arr_tc.target_type, PsDereferencableType)
return arr_tc.target_type.base_type
tc.hook(subscript_hook)
elif not isinstance(arr_tc.target_type, PsDereferencableType):
raise TypificationError( raise TypificationError(
"Type of subscript base is not subscriptable." "Type of subscript base is not subscriptable."
) )
else:
tc.apply_dtype(arr_tc.target_type.base_type, expr) tc.apply_dtype(arr_tc.target_type.base_type, expr)
index_tc = TypeContext() index_tc = TypeContext()
self.visit_expr(idx, index_tc) self.visit_expr(idx, index_tc)
...@@ -683,13 +704,25 @@ class Typifier: ...@@ -683,13 +704,25 @@ class Typifier:
) )
items_tc.apply_dtype(deconstify(elem_type)) items_tc.apply_dtype(deconstify(elem_type))
tc.infer_dtype(expr) tc.infer_dtype(expr)
case DereferencableTo(elem_type_or_hint):
if isinstance(elem_type_or_hint, PsType):
items_tc.apply_dtype(deconstify(elem_type_or_hint))
else:
items_tc.apply_hint(elem_type_or_hint)
tc.infer_dtype(expr)
return PsArrayType(deconstify(items_tc.get_target_type()), len(items))
case ToDefault(): case ToDefault():
items_tc.apply_hint(type_or_hint) items_tc.apply_hint(type_or_hint)
tc.infer_dtype(expr) tc.infer_dtype(expr)
return PsArrayType(deconstify(items_tc.get_target_type()), len(items)) return PsArrayType(deconstify(items_tc.get_target_type()), len(items))
case TypeHint(): case TypeHint():
# Can't deal with any other type hints # Can't deal with any other type hints
return None return None
case other_type: case other_type:
raise TypificationError( raise TypificationError(
f"Cannot apply type {other_type} to array initializer {expr}." f"Cannot apply type {other_type} to array initializer {expr}."
......
...@@ -29,6 +29,7 @@ from pystencils.backend.ast.expressions import ( ...@@ -29,6 +29,7 @@ from pystencils.backend.ast.expressions import (
PsLt, PsLt,
PsCall, PsCall,
PsTernary, PsTernary,
PsArrayInitList,
) )
from pystencils.backend.constants import PsConstant from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import CFunction from pystencils.backend.functions import CFunction
...@@ -220,7 +221,7 @@ def test_constant_decls(): ...@@ -220,7 +221,7 @@ def test_constant_decls():
typify = Typifier(ctx) typify = Typifier(ctx)
x, y = sp.symbols("x, y") x, y = sp.symbols("x, y")
decl = freeze(Assignment(x, 3.0)) decl = freeze(Assignment(x, 3.0))
decl = typify(decl) decl = typify(decl)
assert ctx.get_symbol("x").dtype == Fp(16) assert ctx.get_symbol("x").dtype == Fp(16)
...@@ -250,6 +251,59 @@ def test_constant_array_decls(): ...@@ -250,6 +251,59 @@ def test_constant_array_decls():
assert ctx.get_symbol("y").dtype == Arr(Arr(Fp(16), 4), 2) assert ctx.get_symbol("y").dtype == Arr(Arr(Fp(16), 4), 2)
def test_inline_arrays_1d():
ctx = KernelCreationContext(default_dtype=Fp(16))
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y = sp.symbols("x, y")
idx = TypedSymbol("idx", Int(32))
arr: PsArrayInitList = cast(PsArrayInitList, freeze(sp.Tuple(1, 2, 3, 4)))
decl = PsDeclaration(freeze(x), freeze(y) + PsSubscript(arr, freeze(idx)))
# The array elements should learn their type from the context, which gets it from `y`
decl = typify(decl)
assert decl.lhs.dtype == Fp(16, const=True)
assert decl.rhs.dtype == Fp(16, const=True)
assert arr.dtype == Arr(Fp(16), 4, const=True)
for item in arr.items:
assert item.dtype == Fp(16, const=True)
def test_inline_arrays_3d():
ctx = KernelCreationContext(default_dtype=Fp(16))
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y = sp.symbols("x, y")
idx = [TypedSymbol(f"idx_{i}", Int(32)) for i in range(3)]
arr = freeze(sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10))))
decl = PsDeclaration(
freeze(x),
freeze(y)
+ PsSubscript(
PsSubscript(PsSubscript(arr, freeze(idx[0])), freeze(idx[1])),
freeze(idx[2]),
),
)
# The array elements should learn their type from the context, which gets it from `y`
decl = typify(decl)
assert decl.lhs.dtype == Fp(16, const=True)
assert decl.rhs.dtype == Fp(16, const=True)
assert arr.dtype == Arr(Arr(Arr(Fp(16), 2), 3), 2, const=True)
for item in arr.items:
assert item.dtype == Arr(Arr(Fp(16), 2), 3, const=True)
for iitem in item.items:
assert iitem.dtype == Arr(Fp(16), 2, const=True)
for iiitem in iitem.items:
assert iiitem.dtype == Fp(16, const=True)
def test_lhs_inference(): def test_lhs_inference():
ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
freeze = FreezeExpressions(ctx) freeze = FreezeExpressions(ctx)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment