Skip to content
Snippets Groups Projects
Commit fbcf566f authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Base implementation of inference hooks + default-fallback mechanism via these hooks.

parent a4fb0c00
No related branches found
No related tags found
1 merge request!418Nesting of Type Contexts, Type Hints, and Improved Array Typing
......@@ -22,8 +22,8 @@ def optimize_cpu(
canonicalize = CanonicalizeSymbols(ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
hoist_invariants = HoistLoopInvariantDeclarations(ctx)
kernel_ast = cast(PsBlock, hoist_invariants(kernel_ast))
# hoist_invariants = HoistLoopInvariantDeclarations(ctx)
# kernel_ast = cast(PsBlock, hoist_invariants(kernel_ast))
if cfg is None:
return kernel_ast
......
from __future__ import annotations
from typing import TypeVar
from typing import TypeVar, Callable
from dataclasses import dataclass
from .context import KernelCreationContext
from ...types import (
......@@ -23,6 +24,7 @@ from ..ast.structural import (
PsExpression,
PsAssignment,
PsDeclaration,
PsStatement,
PsEmptyLeafMixIn,
)
from ..ast.expressions import (
......@@ -58,6 +60,38 @@ class TypificationError(Exception):
NodeT = TypeVar("NodeT", bound=PsAstNode)
@dataclass(frozen=True)
class TypeHint:
"""Base class for type hints.
Type hints represent incomplete type information that can be passed up a tree of unresolved type contexts
in order to attempt to resolve them.
"""
pass
@dataclass(frozen=True)
class ToDefault(TypeHint):
default_dtype: PsType
InferenceHook = Callable[[PsType | TypeHint], PsType | None]
"""An inference hook is a callback that is attached to a type context,
to be called once type information about that context is known.
The inference hook will then try to use that information to resolve nested type contexts,
and potentially the context it is attached to as well.
When called with a `PsType`, that type is the target type of the context to which the hook is attached.
The hook has to use this type to resolve any nested type contexts and return `None`;
if it cannot resolve its nested contexts, it must raise a TypificationError.
When called with a `TypeHint`, the inference hook has to attempt to resolve its nested contexts.
If it succeeds, it has to return the data type that must be applied to the outer context.
If it fails, it must return `None`.
"""
class TypeContext:
"""Typing context, with support for type inference and checking.
......@@ -68,6 +102,25 @@ class TypeContext:
- A set of restrictions on the target type:
- `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides
- Additional restrictions may be added in the future.
Each typing context needs to be resolved at some point.
This can happen immediately during its expansion in the following ways:
- During expansion, a node with a fixed type is encountered; then, that type is applied to the context; or
- A type is enforced directly from a surrounding or otherwise associated context; or
- the context is supplied with a type hint through `apply_hint`, which it can either use to figure out its type
directly, or pass on to any number of registered `InferenceHook`s;
one of those *must* then provide the target type.
If a type context cannot be resolved while it is being processed, its resolution needs to be deferred.
It must then be hooked into its surrounding (parent) context using an `InferenceHook`.
Through this hook, it receives two second chances for resolution:
- If the surrounding context gets resolved to a type, the type of the nested context *must* be inferred
from that surrounding type
- If the surrounding context is supplied with a type hint, that type hint is given to the inference hook
which then *may* use that to infer the type of its nested context;
it that is successful, the hook must also provide the type for the surrounding context.
"""
def __init__(
......@@ -82,20 +135,37 @@ class TypeContext:
self._fix_constness(target_type) if target_type is not None else None
)
self._inference_hooks: list[InferenceHook] = []
@property
def target_type(self) -> PsType | None:
return self._target_type
def get_target_type(self) -> PsType:
assert self._target_type is not None
return self._target_type
@property
def require_nonconst(self) -> bool:
return self._require_nonconst
def hook(self, hook: InferenceHook):
"""Add an inference hook to this type context.
If this context already has a known target type, the inference hook will be called immediately.
If it does not, the inference hook is cached to be called later when a type or type hint becomes known.
"""
if self._target_type is not None:
hook(self._target_type)
else:
self._inference_hooks.append(hook)
def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
"""Applies the given ``dtype`` to this type context, and optionally to the given expression.
If the context's target_type is already known, it must be compatible with the given dtype.
If the target type is still unknown, target_type is set to dtype and retroactively applied
to all deferred expressions.
If the target type is still unknown, target_type is set to dtype, retroactively applied
to all deferred expressions, and propagated through any registered inference hooks.
If an expression is specified, it will be covered by the type context.
If the expression already has a data type set, it must be compatible with the target type
......@@ -117,6 +187,30 @@ class TypeContext:
if expr is not None:
self._apply_target_type(expr)
def apply_hint(self, hint: TypeHint):
"""Attempt to resolve this type context from the given type hint.
If the hint is not sufficient to resolve the context, a `TypificationError` is raised.
"""
assert self._target_type is None
# Type hints that can be resolved right there
match hint:
case ToDefault(default_dtype) if not self._inference_hooks:
self.apply_dtype(default_dtype)
case _:
for i, hook in enumerate(self._inference_hooks):
self._target_type = hook(hint)
if self._target_type is not None:
# That hook was successful; remove it so it is not called a second time
del self._inference_hooks[i]
if self._target_type is not None:
# Now we have the target type
self._propagate_target_type()
else:
raise TypificationError(f"Unable to infer context type from hint {hint}")
def infer_dtype(self, expr: PsExpression):
"""Infer the data type for the given expression.
......@@ -134,8 +228,14 @@ class TypeContext:
self._apply_target_type(expr)
def _propagate_target_type(self):
assert self._target_type is not None
for hook in self._inference_hooks:
hook(self._target_type)
for expr in self._deferred_exprs:
self._apply_target_type(expr)
self._deferred_exprs = []
def _apply_target_type(self, expr: PsExpression):
......@@ -293,7 +393,7 @@ class Typifier:
if tc.target_type is None:
# no type could be inferred -> take the default
tc.apply_dtype(self._ctx.default_dtype)
tc.apply_hint(ToDefault(self._ctx.default_dtype))
else:
self.visit(node)
return node
......@@ -332,7 +432,7 @@ class Typifier:
if infer_lhs and tc.target_type is None:
# no type has been inferred -> use the default dtype
tc.apply_dtype(self._ctx.default_dtype)
tc.apply_hint(ToDefault(self._ctx.default_dtype))
case PsAssignment(lhs, rhs):
infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None
......@@ -350,8 +450,8 @@ class Typifier:
if infer_lhs:
if tc_rhs.target_type is None:
tc_rhs.apply_dtype(self._ctx.default_dtype)
tc_rhs.apply_hint(ToDefault(self._ctx.default_dtype))
assert tc_rhs.target_type is not None
tc_lhs.apply_dtype(deconstify(tc_rhs.target_type))
......@@ -572,23 +672,30 @@ class Typifier:
self.visit_expr(item, items_tc)
if items_tc.target_type is None:
if tc.target_type is None:
raise TypificationError(f"Unable to infer type of array {expr}")
elif not isinstance(tc.target_type, PsArrayType):
raise TypificationError(
f"Cannot apply type {tc.target_type} to an array initializer."
)
elif (
tc.target_type.length is not None
and tc.target_type.length != len(items)
):
raise TypificationError(
"Array size mismatch: Cannot typify initializer list with "
f"{len(items)} items as {tc.target_type}"
)
else:
items_tc.apply_dtype(tc.target_type.base_type)
tc.infer_dtype(expr)
# Infer type of items from enclosing context
def hook(type_or_hint: PsType | TypeHint) -> PsType | None:
match type_or_hint:
case PsArrayType(elem_type, length):
if length is not None and length != len(items):
raise TypificationError(
"Array size mismatch: Cannot typify initializer list with "
f"{len(items)} items as {tc.target_type}"
)
items_tc.apply_dtype(deconstify(elem_type))
tc.infer_dtype(expr)
case ToDefault():
items_tc.apply_hint(type_or_hint)
tc.infer_dtype(expr)
return PsArrayType(deconstify(items_tc.get_target_type()), len(items))
case TypeHint():
# Can't deal with any other type hints
return None
case other_type:
raise TypificationError(
f"Cannot apply type {other_type} to array initializer {expr}."
)
tc.hook(hook)
else:
arr_type = PsArrayType(items_tc.target_type, len(items))
tc.apply_dtype(arr_type, expr)
......
......@@ -100,6 +100,8 @@ class PsPointerType(PsDereferencableType):
class PsArrayType(PsDereferencableType):
"""C array type of known or unknown size."""
__match_args__ = ("base_type", "length")
def __init__(
self, base_type: PsType, length: int | None = None, const: bool = False
):
......
......@@ -214,6 +214,42 @@ def test_default_typing():
check(expr)
def test_constant_decls():
ctx = KernelCreationContext(default_dtype=Fp(16))
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y = sp.symbols("x, y")
decl = freeze(Assignment(x, 3.0))
decl = typify(decl)
assert ctx.get_symbol("x").dtype == Fp(16)
assert decl.rhs.dtype == Fp(16, const=True)
assert decl.rhs.constant.dtype == Fp(16, const=True)
decl = freeze(Assignment(y, 42))
decl = typify(decl)
assert ctx.get_symbol("y").dtype == Fp(16)
assert decl.rhs.dtype == Fp(16, const=True)
assert decl.rhs.constant.dtype == Fp(16, const=True)
def test_constant_array_decls():
ctx = KernelCreationContext(default_dtype=Fp(16))
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y = sp.symbols("x, y")
decl = freeze(Assignment(x, (1, 2, 3, 4)))
decl = typify(decl)
assert ctx.get_symbol("x").dtype == Arr(Fp(16), 4)
decl = freeze(Assignment(y, ((1, 2, 3, 4), (5, 6, 7, 8))))
decl = typify(decl)
assert ctx.get_symbol("y").dtype == Arr(Arr(Fp(16), 4), 2)
def test_lhs_inference():
ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
freeze = FreezeExpressions(ctx)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment