Skip to content
Snippets Groups Projects
Commit 45c9fe80 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

contextual typing

parent 0c1ef7fb
No related branches found
No related tags found
No related merge requests found
......@@ -35,11 +35,11 @@ class FreezeExpressions(SympyToPymbolicMapper):
...
@overload
def __call__(self, expr: Assignment) -> PsAssignment:
def __call__(self, expr: sp.Expr) -> PsExpression:
...
@overload
def __call__(self, expr: sp.Basic) -> pb.Expression:
def __call__(self, expr: Assignment) -> PsAssignment:
...
def __call__(self, obj):
......@@ -47,8 +47,8 @@ class FreezeExpressions(SympyToPymbolicMapper):
return PsBlock([self.rec(asm) for asm in obj.all_assignments])
elif isinstance(obj, Assignment):
return cast(PsAssignment, self.rec(obj))
elif isinstance(obj, sp.Basic):
return cast(pb.Expression, self.rec(obj))
elif isinstance(obj, sp.Expr):
return PsExpression(cast(pb.Expression, self.rec(obj)))
else:
raise PsInputError(f"Don't know how to freeze {obj}")
......
from __future__ import annotations
from typing import TypeVar, Any, Sequence, cast
from typing import TypeVar, Any, NoReturn
import pymbolic.primitives as pb
from pymbolic.mapper import Mapper
......@@ -10,6 +10,9 @@ from ..types import PsAbstractType, PsNumericType, deconstify
from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
from ..arrays import PsArrayAccess
from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment
from ..exceptions import PsInternalCompilerError
__all__ = ["Typifier"]
class TypificationError(Exception):
......@@ -19,6 +22,72 @@ class TypificationError(Exception):
NodeT = TypeVar("NodeT", bound=PsAstNode)
class UndeterminedType(PsNumericType):
def create_constant(self, value: Any) -> Any:
return None
def _err(self) -> NoReturn:
raise PsInternalCompilerError("Calling UndeterminedType.")
def create_literal(self, value: Any) -> str:
self._err()
def is_int(self) -> bool:
self._err()
def is_sint(self) -> bool:
self._err()
def is_uint(self) -> bool:
self._err()
def is_float(self) -> bool:
self._err()
def __eq__(self, other: object) -> bool:
self._err()
def _c_string(self) -> str:
self._err()
class DeferredTypedConstant(PsTypedConstant):
"""Special subclass for constants whose types cannot be determined yet at the time of their creation.
Outside of the typifier, a DeferredTypedConstant acts exactly the same way as a PsTypedConstant.
"""
def __init__(self, value: Any):
self._value_deferred = value
def resolve(self, dtype: PsNumericType):
super().__init__(self._value_deferred, dtype)
class TypeContext:
def __init__(self, target_type: PsNumericType | None):
self._target_type = deconstify(target_type) if target_type is not None else None
self._deferred_constants: list[DeferredTypedConstant] = []
def make_constant(self, value: Any) -> PsTypedConstant:
if self._target_type is None:
dc = DeferredTypedConstant(value)
self._deferred_constants.append(dc)
return dc
else:
return PsTypedConstant(value, self._target_type)
def apply(self, target_type: PsNumericType):
assert self._target_type is None, "Type context was already resolved"
self._target_type = deconstify(target_type)
for dc in self._deferred_constants:
dc.resolve(self._target_type)
@property
def target_type(self) -> PsNumericType | None:
return self._target_type
class Typifier(Mapper):
"""Typifier for untyped expressions.
......@@ -27,14 +96,33 @@ class Typifier(Mapper):
- Plain variables will be assigned a type according to `ctx.options.default_dtype`.
- Constants will be converted to typed constants by applying the target type of the current context.
If the target type is unknown, typification of constants will fail.
The target type for an expression must either be provided by the user or is inferred from the context.
The two primary contexts are an assignment, where the target type of the right-hand side expression is
given by the type of the left-hand side; and the index expression of an array access, where the target
type is given by `ctx.options.index_dtype`.
The target type is propagated upward through the expression tree. It is applied to all untyped constants,
and used to check the correctness of the types of expressions.
Contextual Typing
-----------------
Starting at an expression's root, the typifier attempts to expand a typing context as far as possible.
This happens implicitly during the recursive traversal of the expression tree.
At an interior node, which is modelled as a function applied to a number of arguments, producing a result,
that function's signature governs context expansion. Let T be the function's return type; then the context
is expanded to each argument expression that also is of type T.
If a function parameter is of type S != T, a new type context is created for it. If the type S is already fixed
by the function signature, it will be the target type of the new context.
At the tree's leaves, types are applied and checked. By the above propagation rule, all leaves that share a typing
context must have the exact same type (modulo constness). This type is checked at variables, and applied to
constants.
It may happen that the typifier arrives at a constant before the context's target type could be figured out.
In that case, the constant will first be instantiated as a DeferredTypedConstant, and stashed in the context.
As soon as the context learns its target type, it is applied to all deferred constants.
When a context is 'closed' during the recursive unwinding, it shall be an error if it still contains unresolved
constants.
TODO: The context shall keep track of it's target type's origin to aid in producing helpful error messages.
"""
def __init__(self, ctx: KernelCreationContext):
......@@ -46,18 +134,15 @@ class Typifier(Mapper):
node.statements = [self(s) for s in statements]
case PsExpression(expr):
node.expression, _ = self.rec(expr)
node.expression = self.rec(expr, TypeContext(None))
case PsAssignment(lhs, rhs):
new_lhs, lhs_dtype = self.rec(lhs.expression, None)
new_rhs, rhs_dtype = self.rec(rhs.expression, lhs_dtype)
if lhs_dtype != rhs_dtype:
raise TypificationError(
"Mismatched types in assignment: \n"
f" {lhs} <- {rhs}\n"
f" dtype(lhs) = {lhs_dtype}\n"
f" dtype(rhs) = {rhs_dtype}\n"
)
tc = TypeContext(None)
# LHS defines target type; type context carries it to RHS
new_lhs = self.rec(lhs.expression, tc)
assert tc.target_type is not None
new_rhs = self.rec(rhs.expression, tc)
node.lhs.expression = new_lhs
node.rhs.expression = new_rhs
......@@ -67,7 +152,7 @@ class Typifier(Mapper):
return node
"""
def rec(self, expr: Any, target_type: PsNumericType | None)
def rec(self, expr: Any, tc: TypeContext) -> ExprOrConstant
All visitor methods take an expression and the target type of the current context.
They shall return the typified expression together with its type.
......@@ -77,106 +162,59 @@ class Typifier(Mapper):
def typify_expression(
self, expr: Any, target_type: PsNumericType | None = None
) -> ExprOrConstant:
return self.rec(expr, target_type)
return self.rec(expr, TypeContext(target_type))
# Leaf nodes: Variables, Typed Variables, Constants and TypedConstants
def map_typed_variable(
self, var: PsTypedVariable, target_type: PsNumericType | None
):
self._check_target_type(var, var.dtype, target_type)
return var, deconstify(var.dtype)
def map_typed_variable(self, var: PsTypedVariable, tc: TypeContext):
self._apply_target_type(var, var.dtype, tc)
return var
def map_variable(
self, var: pb.Variable, target_type: PsNumericType | None
) -> tuple[PsTypedVariable, PsNumericType]:
def map_variable(self, var: pb.Variable, tc: TypeContext) -> PsTypedVariable:
dtype = self._ctx.options.default_dtype
typed_var = PsTypedVariable(var.name, dtype)
self._check_target_type(typed_var, dtype, target_type)
return typed_var, deconstify(dtype)
self._apply_target_type(typed_var, dtype, tc)
return typed_var
def map_constant(
self, value: Any, target_type: PsNumericType | None
) -> tuple[PsTypedConstant, PsNumericType]:
def map_constant(self, value: Any, tc: TypeContext) -> PsTypedConstant:
if isinstance(value, PsTypedConstant):
self._check_target_type(value, value.dtype, target_type)
return value, deconstify(value.dtype)
elif target_type is None:
raise TypificationError(
f"Unable to typify constant {value}: Unknown target type in this context."
)
else:
return PsTypedConstant(value, target_type), deconstify(target_type)
self._apply_target_type(value, value.dtype, tc)
return value
return tc.make_constant(value)
# Array Access
def map_array_access(
self, access: PsArrayAccess, target_type: PsNumericType | None
) -> tuple[PsArrayAccess, PsNumericType]:
self._check_target_type(access, access.dtype, target_type)
index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype)
return PsArrayAccess(access.base_ptr, index), cast(
PsNumericType, deconstify(access.dtype)
def map_array_access(self, access: PsArrayAccess, tc: TypeContext) -> PsArrayAccess:
self._apply_target_type(access, access.dtype, tc)
index, _ = self.rec(
access.index_tuple[0], TypeContext(self._ctx.options.index_dtype)
)
return PsArrayAccess(access.base_ptr, index)
# Arithmetic Expressions
def _homogenize(
self,
expr: pb.Expression,
args: Sequence[Any],
target_type: PsNumericType | None,
) -> tuple[tuple[ExprOrConstant], PsNumericType]:
"""Typify all arguments of a multi-argument expression with the same type."""
new_args = [None] * len(args)
common_type: PsNumericType | None = None
for i, c in enumerate(args):
new_args[i], arg_i_type = self.rec(c, target_type)
if common_type is None:
common_type = arg_i_type
elif common_type != arg_i_type:
raise TypificationError(
f"Type mismatch in expression {expr}: Type of operand {i} did not match previous operands\n"
f" Previous type: {common_type}\n"
f" Operand {i} type: {arg_i_type}"
)
assert common_type is not None
return cast(tuple[ExprOrConstant], tuple(new_args)), common_type
def map_sum(
self, expr: pb.Sum, target_type: PsNumericType | None
) -> tuple[pb.Sum, PsNumericType]:
new_args, dtype = self._homogenize(expr, expr.children, target_type)
return pb.Sum(new_args), dtype
def map_product(
self, expr: pb.Product, target_type: PsNumericType | None
) -> tuple[pb.Product, PsNumericType]:
new_args, dtype = self._homogenize(expr, expr.children, target_type)
return pb.Product(new_args), dtype
def map_call(
self, expr: pb.Call, target_type: PsNumericType | None
) -> tuple[pb.Call, PsNumericType]:
"""
TODO: Figure out the best way to typify functions
def map_sum(self, expr: pb.Sum, tc: TypeContext) -> pb.Sum:
return pb.Sum(tuple(self.rec(c, tc) for c in expr.children))
- How to propagate target_type in the face of multiple overloads?
def map_product(self, expr: pb.Product, tc: TypeContext) -> pb.Product:
return pb.Product(tuple(self.rec(c, tc) for c in expr.children))
def map_call(self, expr: pb.Call, tc: TypeContext) -> pb.Call:
"""
TODO: Figure out how to describe function signatures
"""
raise NotImplementedError()
def _check_target_type(
self,
expr: ExprOrConstant,
expr_type: PsAbstractType,
target_type: PsNumericType | None,
def _apply_target_type(
self, expr: ExprOrConstant, expr_type: PsAbstractType, tc: TypeContext
):
if target_type is not None and deconstify(expr_type) != deconstify(target_type):
if tc.target_type is None:
assert isinstance(expr_type, PsNumericType)
tc.apply(expr_type)
elif deconstify(expr_type) != tc.target_type:
raise TypificationError(
f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
f" Expression type: {expr_type}\n"
f" Target type: {target_type}"
f" Target type: {tc.target_type}"
)
......@@ -35,12 +35,36 @@ def test_typify_simple():
case PsTypedVariable(name, dtype):
assert name in "xyz"
assert dtype == ctx.options.default_dtype
case pb.Variable:
pytest.fail("Encountered untyped variable")
case pb.Sum(cs) | pb.Product(cs):
[check(c) for c in cs]
case _:
pytest.fail("Non-exhaustive pattern matcher.")
pytest.fail(f"Unexpected expression: {expr}")
check(fasm.lhs.expression)
check(fasm.rhs.expression)
def test_contextual_typing():
options = KernelCreationOptions()
ctx = KernelCreationContext(options)
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y, z = sp.symbols("x, y, z")
expr = freeze(2 * x + 3 * y + z - 4)
expr = typify(expr)
def check(expr):
match expr:
case PsTypedConstant(value, dtype):
assert value in (2, 3, -4)
assert dtype == constify(ctx.options.default_dtype)
case PsTypedVariable(name, dtype):
assert name in "xyz"
assert dtype == ctx.options.default_dtype
case pb.Sum(cs) | pb.Product(cs):
[check(c) for c in cs]
case _:
pytest.fail(f"Unexpected expression: {expr}")
check(expr.expression)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment