Skip to content
Snippets Groups Projects

Fix handling of constness in Typifier

Merged Frederik Hennig requested to merge fhennig/fix-const-typing into backend-rework
All threads resolved!
3 files
+ 236
52
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -12,7 +12,7 @@ from ...types import (
PsDereferencableType,
PsPointerType,
PsBoolType,
deconstify,
constify,
)
from ..ast.structural import (
PsAstNode,
@@ -21,6 +21,7 @@ from ..ast.structural import (
PsConditional,
PsExpression,
PsAssignment,
PsDeclaration,
)
from ..ast.expressions import (
PsArrayAccess,
@@ -54,20 +55,48 @@ NodeT = TypeVar("NodeT", bound=PsAstNode)
class TypeContext:
def __init__(self, target_type: PsType | None = None):
self._target_type = deconstify(target_type) if target_type is not None else None
"""Typing context, with support for type inference and checking.
Instances of this class are used to propagate and check data types across expression subtrees
of the AST. Each type context has:
- A target type `target_type`, which shall be applied to all expressions it covers
- A set of restrictions on the target type:
- `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides
- Additional restrictions may be added in the future.
"""
def __init__(
self, target_type: PsType | None = None, require_nonconst: bool = False
):
self._require_nonconst = require_nonconst
self._deferred_exprs: list[PsExpression] = []
def apply_dtype(self, expr: PsExpression | None, dtype: PsType):
"""Applies the given ``dtype`` to the given expression inside this type context.
self._target_type = (
self._fix_constness(target_type) if target_type is not None else None
)
@property
def target_type(self) -> PsType | None:
return self._target_type
@property
def require_nonconst(self) -> bool:
return self._require_nonconst
def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
"""Applies the given ``dtype`` to this type context, and optionally to the given expression.
The given expression will be covered by this type context.
If the context's target_type is already known, it must be compatible with the given dtype.
If the target type is still unknown, target_type is set to dtype and retroactively applied
to all deferred expressions.
If an expression is specified, it will be covered by the type context.
If the expression already has a data type set, it must be compatible with the target type
and will be replaced by it.
"""
dtype = deconstify(dtype)
dtype = self._fix_constness(dtype)
if self._target_type is not None and dtype != self._target_type:
raise TypificationError(
@@ -80,14 +109,7 @@ class TypeContext:
self._propagate_target_type()
if expr is not None:
if expr.dtype is None:
self._apply_target_type(expr)
elif deconstify(expr.dtype) != self._target_type:
raise TypificationError(
"Type conflict: Predefined expression type did not match the context's target type\n"
f" Expression type: {dtype}\n"
f" Target type: {self._target_type}"
)
self._apply_target_type(expr)
def infer_dtype(self, expr: PsExpression):
"""Infer the data type for the given expression.
@@ -96,7 +118,8 @@ class TypeContext:
Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is
called on this context.
If the expression already has a data type set, it must be equal to the inferred type.
If the expression already has a data type set, it must be compatible with the target type
and will be replaced by it.
"""
if self._target_type is None:
@@ -113,7 +136,7 @@ class TypeContext:
assert self._target_type is not None
if expr.dtype is not None:
if deconstify(expr.dtype) != self.target_type:
if not self._compatible(expr.dtype):
raise TypificationError(
f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
f" Expression type: {expr.dtype}\n"
@@ -128,7 +151,7 @@ class TypeContext:
)
if c.dtype is None:
expr.constant = c.interpret_as(self._target_type)
elif deconstify(c.dtype) != self._target_type:
elif not self._compatible(c.dtype):
raise TypificationError(
f"Type mismatch at constant {c}: Constant type did not match the context's target type\n"
f" Constant type: {c.dtype}\n"
@@ -136,7 +159,13 @@ class TypeContext:
)
case PsSymbolExpr(symb):
symb.apply_dtype(self._target_type)
assert symb.dtype is not None
if not self._compatible(symb.dtype):
raise TypificationError(
f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n"
f" Symbol type: {symb.dtype}\n"
f" Target type: {self._target_type}"
)
case (
PsIntDiv()
@@ -151,18 +180,42 @@ class TypeContext:
f" Expression: {expr}"
f" Type Context: {self._target_type}"
)
expr.dtype = self._target_type
# endif
expr.dtype = self._target_type
@property
def target_type(self) -> PsType | None:
return self._target_type
def _compatible(self, dtype: PsType):
"""Checks whether the given data type is compatible with the context's target type.
If the target type is ``const``, they must be equal up to const qualification;
if the target type is not ``const``, `dtype` must match it exactly.
"""
assert self._target_type is not None
if self._target_type.const:
return constify(dtype) == self._target_type
else:
return dtype == self._target_type
def _fix_constness(self, dtype: PsType, expr: PsExpression | None = None):
if self._require_nonconst:
if dtype.const:
if expr is None:
raise TypificationError(
f"Type mismatch: Encountered {dtype} in non-constant context."
)
else:
raise TypificationError(
f"Type mismatch at expression {expr}: Encountered {dtype} in non-constant context."
)
return dtype
else:
return constify(dtype)
class Typifier:
"""Apply data types to expressions.
**Contextual Typing**
The Typifier will traverse the AST and apply a contextual typing scheme to figure out
the data types of all encountered expressions.
To this end, it covers each expression tree with a set of disjoint typing contexts.
@@ -183,6 +236,21 @@ class Typifier:
in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon
as the target type is fixed.
**Typing Rules**
The following general rules apply:
- The context's `default_dtype` is applied to all untyped symbols
- By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's
left-hand side
**Typing of symbol expressions**
Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but
not necessarily their const-qualification.
A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type,
and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`,
but not vice versa.
"""
def __init__(self, ctx: KernelCreationContext):
@@ -213,13 +281,21 @@ class Typifier:
for s in statements:
self.visit(s)
case PsAssignment(lhs, rhs):
case PsDeclaration(lhs, rhs):
tc = TypeContext()
# LHS defines target type; type context carries it to RHS
self.visit_expr(lhs, tc)
assert tc.target_type is not None
self.visit_expr(rhs, tc)
case PsAssignment(lhs, rhs):
tc_lhs = TypeContext(require_nonconst=True)
self.visit_expr(lhs, tc_lhs)
assert tc_lhs.target_type is not None
tc_rhs = TypeContext(tc_lhs.target_type, require_nonconst=False)
self.visit_expr(rhs, tc_rhs)
case PsConditional(cond, branch_true, branch_false):
cond_tc = TypeContext(PsBoolType(const=True))
self.visit_expr(cond, cond_tc)
@@ -233,10 +309,10 @@ class Typifier:
if ctr.symbol.dtype is None:
ctr.symbol.apply_dtype(self._ctx.index_dtype)
tc = TypeContext(ctr.symbol.dtype)
self.visit_expr(start, tc)
self.visit_expr(stop, tc)
self.visit_expr(step, tc)
tc_index = TypeContext(ctr.symbol.dtype)
self.visit_expr(start, tc_index)
self.visit_expr(stop, tc_index)
self.visit_expr(step, tc_index)
self.visit(body)
@@ -244,24 +320,35 @@ class Typifier:
raise NotImplementedError(f"Can't typify {node}")
def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None:
"""Recursive processing of expression nodes"""
"""Recursive processing of expression nodes.
This method opens, expands, and closes typing contexts according to the respective expression's
typing rules. It may add or check restrictions only when opening or closing a type context.
The actual type inference and checking during context expansion are performed by the methods
of `TypeContext`. ``visit_expr`` tells the typing context how to handle an expression by calling
either ``apply_dtype`` or ``infer_dtype``.
"""
match expr:
case PsSymbolExpr(_):
if expr.dtype is None:
tc.apply_dtype(expr, self._ctx.default_dtype)
else:
tc.apply_dtype(expr, expr.dtype)
if expr.symbol.dtype is None:
expr.symbol.dtype = self._ctx.default_dtype
case PsConstantExpr(_):
tc.infer_dtype(expr)
tc.apply_dtype(expr.symbol.dtype, expr)
case PsConstantExpr(c):
if c.dtype is not None:
tc.apply_dtype(c.dtype, expr)
else:
tc.infer_dtype(expr)
case PsArrayAccess(bptr, idx):
tc.apply_dtype(expr, bptr.array.element_type)
tc.apply_dtype(bptr.array.element_type, expr)
index_tc = TypeContext()
self.visit_expr(idx, index_tc)
if index_tc.target_type is None:
index_tc.apply_dtype(idx, self._ctx.index_dtype)
index_tc.apply_dtype(self._ctx.index_dtype, idx)
elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError(
f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
@@ -276,12 +363,12 @@ class Typifier:
"Type of subscript base is not subscriptable."
)
tc.apply_dtype(expr, arr_tc.target_type.base_type)
tc.apply_dtype(arr_tc.target_type.base_type, expr)
index_tc = TypeContext()
self.visit_expr(idx, index_tc)
if index_tc.target_type is None:
index_tc.apply_dtype(idx, self._ctx.index_dtype)
index_tc.apply_dtype(self._ctx.index_dtype, idx)
elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError(
f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}"
@@ -296,7 +383,7 @@ class Typifier:
"Type of argument to a Deref is not dereferencable"
)
tc.apply_dtype(expr, ptr_tc.target_type.base_type)
tc.apply_dtype(ptr_tc.target_type.base_type, expr)
case PsAddressOf(arg):
arg_tc = TypeContext()
@@ -308,10 +395,11 @@ class Typifier:
)
ptr_type = PsPointerType(arg_tc.target_type, True)
tc.apply_dtype(expr, ptr_type)
tc.apply_dtype(ptr_type, expr)
case PsLookup(aggr, member_name):
aggr_tc = TypeContext(None)
# Members of a struct type inherit the struct type's `const` qualifier
aggr_tc = TypeContext(None, require_nonconst=tc.require_nonconst)
self.visit_expr(aggr, aggr_tc)
aggr_type = aggr_tc.target_type
@@ -326,7 +414,11 @@ class Typifier:
f"Aggregate of type {aggr_type} does not have a member {member}."
)
tc.apply_dtype(expr, member.dtype)
member_type = member.dtype
if aggr_type.const:
member_type = constify(member_type)
tc.apply_dtype(member_type, expr)
case PsBinOp(op1, op2):
self.visit_expr(op1, tc)
@@ -365,14 +457,14 @@ class Typifier:
f"{len(items)} items as {tc.target_type}"
)
else:
items_tc.apply_dtype(None, tc.target_type.base_type)
items_tc.apply_dtype(tc.target_type.base_type)
else:
arr_type = PsArrayType(items_tc.target_type, len(items))
tc.apply_dtype(expr, arr_type)
tc.apply_dtype(arr_type, expr)
case PsCast(dtype, arg):
self.visit_expr(arg, TypeContext())
tc.apply_dtype(expr, dtype)
tc.apply_dtype(dtype, expr)
case _:
raise NotImplementedError(f"Can't typify {expr}")
Loading