Skip to content
Snippets Groups Projects
Commit 191cc207 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'fhennig/fix-const-typing' into 'backend-rework'

Fix handling of constness in Typifier

See merge request !372
parents 6c649d20 7f5ffb5b
No related tags found
1 merge request!372Fix handling of constness in Typifier
Pipeline #64876 passed
......@@ -59,9 +59,14 @@ class PsConstant:
@property
def dtype(self) -> PsNumericType | None:
"""This constant's data type, or ``None`` if it is untyped.
The data type of a constant always has ``const == True``.
"""
return self._dtype
def get_dtype(self) -> PsNumericType:
"""Retrieve this constant's data type, throwing an exception if the constant is untyped."""
if self._dtype is None:
raise PsInternalCompilerError("Data type of constant was not set.")
return self._dtype
......
......@@ -12,7 +12,7 @@ from ...types import (
PsDereferencableType,
PsPointerType,
PsBoolType,
deconstify,
constify,
)
from ..ast.structural import (
PsAstNode,
......@@ -21,6 +21,7 @@ from ..ast.structural import (
PsConditional,
PsExpression,
PsAssignment,
PsDeclaration,
)
from ..ast.expressions import (
PsArrayAccess,
......@@ -54,20 +55,48 @@ NodeT = TypeVar("NodeT", bound=PsAstNode)
class TypeContext:
def __init__(self, target_type: PsType | None = None):
self._target_type = deconstify(target_type) if target_type is not None else None
"""Typing context, with support for type inference and checking.
Instances of this class are used to propagate and check data types across expression subtrees
of the AST. Each type context has:
- A target type `target_type`, which shall be applied to all expressions it covers
- A set of restrictions on the target type:
- `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides
- Additional restrictions may be added in the future.
"""
def __init__(
self, target_type: PsType | None = None, require_nonconst: bool = False
):
self._require_nonconst = require_nonconst
self._deferred_exprs: list[PsExpression] = []
def apply_dtype(self, expr: PsExpression | None, dtype: PsType):
"""Applies the given ``dtype`` to the given expression inside this type context.
self._target_type = (
self._fix_constness(target_type) if target_type is not None else None
)
@property
def target_type(self) -> PsType | None:
return self._target_type
@property
def require_nonconst(self) -> bool:
return self._require_nonconst
def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
"""Applies the given ``dtype`` to this type context, and optionally to the given expression.
The given expression will be covered by this type context.
If the context's target_type is already known, it must be compatible with the given dtype.
If the target type is still unknown, target_type is set to dtype and retroactively applied
to all deferred expressions.
If an expression is specified, it will be covered by the type context.
If the expression already has a data type set, it must be compatible with the target type
and will be replaced by it.
"""
dtype = deconstify(dtype)
dtype = self._fix_constness(dtype)
if self._target_type is not None and dtype != self._target_type:
raise TypificationError(
......@@ -80,14 +109,7 @@ class TypeContext:
self._propagate_target_type()
if expr is not None:
if expr.dtype is None:
self._apply_target_type(expr)
elif deconstify(expr.dtype) != self._target_type:
raise TypificationError(
"Type conflict: Predefined expression type did not match the context's target type\n"
f" Expression type: {dtype}\n"
f" Target type: {self._target_type}"
)
self._apply_target_type(expr)
def infer_dtype(self, expr: PsExpression):
"""Infer the data type for the given expression.
......@@ -96,7 +118,8 @@ class TypeContext:
Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is
called on this context.
If the expression already has a data type set, it must be equal to the inferred type.
If the expression already has a data type set, it must be compatible with the target type
and will be replaced by it.
"""
if self._target_type is None:
......@@ -113,7 +136,7 @@ class TypeContext:
assert self._target_type is not None
if expr.dtype is not None:
if deconstify(expr.dtype) != self.target_type:
if not self._compatible(expr.dtype):
raise TypificationError(
f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
f" Expression type: {expr.dtype}\n"
......@@ -128,7 +151,7 @@ class TypeContext:
)
if c.dtype is None:
expr.constant = c.interpret_as(self._target_type)
elif deconstify(c.dtype) != self._target_type:
elif not self._compatible(c.dtype):
raise TypificationError(
f"Type mismatch at constant {c}: Constant type did not match the context's target type\n"
f" Constant type: {c.dtype}\n"
......@@ -136,7 +159,13 @@ class TypeContext:
)
case PsSymbolExpr(symb):
symb.apply_dtype(self._target_type)
assert symb.dtype is not None
if not self._compatible(symb.dtype):
raise TypificationError(
f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n"
f" Symbol type: {symb.dtype}\n"
f" Target type: {self._target_type}"
)
case (
PsIntDiv()
......@@ -151,18 +180,42 @@ class TypeContext:
f" Expression: {expr}"
f" Type Context: {self._target_type}"
)
expr.dtype = self._target_type
# endif
expr.dtype = self._target_type
@property
def target_type(self) -> PsType | None:
return self._target_type
def _compatible(self, dtype: PsType):
"""Checks whether the given data type is compatible with the context's target type.
If the target type is ``const``, they must be equal up to const qualification;
if the target type is not ``const``, `dtype` must match it exactly.
"""
assert self._target_type is not None
if self._target_type.const:
return constify(dtype) == self._target_type
else:
return dtype == self._target_type
def _fix_constness(self, dtype: PsType, expr: PsExpression | None = None):
if self._require_nonconst:
if dtype.const:
if expr is None:
raise TypificationError(
f"Type mismatch: Encountered {dtype} in non-constant context."
)
else:
raise TypificationError(
f"Type mismatch at expression {expr}: Encountered {dtype} in non-constant context."
)
return dtype
else:
return constify(dtype)
class Typifier:
"""Apply data types to expressions.
**Contextual Typing**
The Typifier will traverse the AST and apply a contextual typing scheme to figure out
the data types of all encountered expressions.
To this end, it covers each expression tree with a set of disjoint typing contexts.
......@@ -183,6 +236,21 @@ class Typifier:
in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon
as the target type is fixed.
**Typing Rules**
The following general rules apply:
- The context's `default_dtype` is applied to all untyped symbols
- By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's
left-hand side
**Typing of symbol expressions**
Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but
not necessarily their const-qualification.
A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type,
and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`,
but not vice versa.
"""
def __init__(self, ctx: KernelCreationContext):
......@@ -213,13 +281,21 @@ class Typifier:
for s in statements:
self.visit(s)
case PsAssignment(lhs, rhs):
case PsDeclaration(lhs, rhs):
tc = TypeContext()
# LHS defines target type; type context carries it to RHS
self.visit_expr(lhs, tc)
assert tc.target_type is not None
self.visit_expr(rhs, tc)
case PsAssignment(lhs, rhs):
tc_lhs = TypeContext(require_nonconst=True)
self.visit_expr(lhs, tc_lhs)
assert tc_lhs.target_type is not None
tc_rhs = TypeContext(tc_lhs.target_type, require_nonconst=False)
self.visit_expr(rhs, tc_rhs)
case PsConditional(cond, branch_true, branch_false):
cond_tc = TypeContext(PsBoolType(const=True))
self.visit_expr(cond, cond_tc)
......@@ -233,10 +309,10 @@ class Typifier:
if ctr.symbol.dtype is None:
ctr.symbol.apply_dtype(self._ctx.index_dtype)
tc = TypeContext(ctr.symbol.dtype)
self.visit_expr(start, tc)
self.visit_expr(stop, tc)
self.visit_expr(step, tc)
tc_index = TypeContext(ctr.symbol.dtype)
self.visit_expr(start, tc_index)
self.visit_expr(stop, tc_index)
self.visit_expr(step, tc_index)
self.visit(body)
......@@ -244,24 +320,35 @@ class Typifier:
raise NotImplementedError(f"Can't typify {node}")
def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None:
"""Recursive processing of expression nodes"""
"""Recursive processing of expression nodes.
This method opens, expands, and closes typing contexts according to the respective expression's
typing rules. It may add or check restrictions only when opening or closing a type context.
The actual type inference and checking during context expansion are performed by the methods
of `TypeContext`. ``visit_expr`` tells the typing context how to handle an expression by calling
either ``apply_dtype`` or ``infer_dtype``.
"""
match expr:
case PsSymbolExpr(_):
if expr.dtype is None:
tc.apply_dtype(expr, self._ctx.default_dtype)
else:
tc.apply_dtype(expr, expr.dtype)
if expr.symbol.dtype is None:
expr.symbol.dtype = self._ctx.default_dtype
case PsConstantExpr(_):
tc.infer_dtype(expr)
tc.apply_dtype(expr.symbol.dtype, expr)
case PsConstantExpr(c):
if c.dtype is not None:
tc.apply_dtype(c.dtype, expr)
else:
tc.infer_dtype(expr)
case PsArrayAccess(bptr, idx):
tc.apply_dtype(expr, bptr.array.element_type)
tc.apply_dtype(bptr.array.element_type, expr)
index_tc = TypeContext()
self.visit_expr(idx, index_tc)
if index_tc.target_type is None:
index_tc.apply_dtype(idx, self._ctx.index_dtype)
index_tc.apply_dtype(self._ctx.index_dtype, idx)
elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError(
f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
......@@ -276,12 +363,12 @@ class Typifier:
"Type of subscript base is not subscriptable."
)
tc.apply_dtype(expr, arr_tc.target_type.base_type)
tc.apply_dtype(arr_tc.target_type.base_type, expr)
index_tc = TypeContext()
self.visit_expr(idx, index_tc)
if index_tc.target_type is None:
index_tc.apply_dtype(idx, self._ctx.index_dtype)
index_tc.apply_dtype(self._ctx.index_dtype, idx)
elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError(
f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}"
......@@ -296,7 +383,7 @@ class Typifier:
"Type of argument to a Deref is not dereferencable"
)
tc.apply_dtype(expr, ptr_tc.target_type.base_type)
tc.apply_dtype(ptr_tc.target_type.base_type, expr)
case PsAddressOf(arg):
arg_tc = TypeContext()
......@@ -308,10 +395,11 @@ class Typifier:
)
ptr_type = PsPointerType(arg_tc.target_type, True)
tc.apply_dtype(expr, ptr_type)
tc.apply_dtype(ptr_type, expr)
case PsLookup(aggr, member_name):
aggr_tc = TypeContext(None)
# Members of a struct type inherit the struct type's `const` qualifier
aggr_tc = TypeContext(None, require_nonconst=tc.require_nonconst)
self.visit_expr(aggr, aggr_tc)
aggr_type = aggr_tc.target_type
......@@ -326,7 +414,11 @@ class Typifier:
f"Aggregate of type {aggr_type} does not have a member {member}."
)
tc.apply_dtype(expr, member.dtype)
member_type = member.dtype
if aggr_type.const:
member_type = constify(member_type)
tc.apply_dtype(member_type, expr)
case PsBinOp(op1, op2):
self.visit_expr(op1, tc)
......@@ -365,14 +457,14 @@ class Typifier:
f"{len(items)} items as {tc.target_type}"
)
else:
items_tc.apply_dtype(None, tc.target_type.base_type)
items_tc.apply_dtype(tc.target_type.base_type)
else:
arr_type = PsArrayType(items_tc.target_type, len(items))
tc.apply_dtype(expr, arr_type)
tc.apply_dtype(arr_type, expr)
case PsCast(dtype, arg):
self.visit_expr(arg, TypeContext())
tc.apply_dtype(expr, dtype)
tc.apply_dtype(dtype, expr)
case _:
raise NotImplementedError(f"Can't typify {expr}")
......@@ -6,11 +6,11 @@ from typing import cast
from pystencils import Assignment, TypedSymbol, Field, FieldType
from pystencils.backend.ast.structural import PsDeclaration
from pystencils.backend.ast.structural import PsDeclaration, PsAssignment, PsExpression
from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp
from pystencils.backend.constants import PsConstant
from pystencils.types import constify
from pystencils.types.quick import Fp, create_numeric_type
from pystencils.types.quick import Fp, create_type, create_numeric_type
from pystencils.backend.kernelcreation.context import KernelCreationContext
from pystencils.backend.kernelcreation.freeze import FreezeExpressions
from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
......@@ -38,7 +38,7 @@ def test_typify_simple():
assert isinstance(fasm, PsDeclaration)
def check(expr):
assert expr.dtype == ctx.default_dtype
assert expr.dtype == constify(ctx.default_dtype)
match expr:
case PsConstantExpr(cs):
assert cs.value == 2
......@@ -56,6 +56,89 @@ def test_typify_simple():
check(fasm.rhs)
def test_rhs_constness():
default_type = Fp(32)
ctx = KernelCreationContext(default_dtype=default_type)
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
f = Field.create_generic(
"f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
)
f_const = Field.create_generic(
"f_const",
1,
index_shape=(1,),
dtype=constify(default_type),
field_type=FieldType.CUSTOM,
)
x, y, z = sp.symbols("x, y, z")
# Right-hand sides should always get const types
asm = typify(freeze(Assignment(x, f.absolute_access([0], [0]))))
assert asm.rhs.get_dtype().const
asm = typify(
freeze(
Assignment(
f.absolute_access([0], [0]),
f.absolute_access([0], [0]) * f_const.absolute_access([0], [0]) * x + y,
)
)
)
assert asm.rhs.get_dtype().const
def test_lhs_constness():
default_type = Fp(32)
ctx = KernelCreationContext(default_dtype=default_type)
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
f = Field.create_generic(
"f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
)
f_const = Field.create_generic(
"f_const",
1,
index_shape=(1,),
dtype=constify(default_type),
field_type=FieldType.CUSTOM,
)
x, y, z = sp.symbols("x, y, z")
# Assignment RHS may not be const
asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y)))
assert not asm.lhs.get_dtype().const
# Cannot assign to const left-hand side
with pytest.raises(TypificationError):
_ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y)))
np_struct = np.dtype([("size", np.uint32), ("data", np.float32)])
struct_type = constify(create_type(np_struct))
struct_field = Field.create_generic(
"struct_field", 1, dtype=struct_type, field_type=FieldType.CUSTOM
)
with pytest.raises(TypificationError):
_ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x)))
# Const LHS is only OK in declarations
q = ctx.get_symbol("q", Fp(32, const=True))
ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q))
ast = typify(ast)
assert ast.lhs.dtype == Fp(32, const=True)
ast = PsAssignment(PsExpression.make(q), PsExpression.make(q))
with pytest.raises(TypificationError):
typify(ast)
def test_typify_structs():
ctx = KernelCreationContext(default_dtype=Fp(32))
freeze = FreezeExpressions(ctx)
......@@ -70,6 +153,10 @@ def test_typify_structs():
fasm = freeze(asm)
fasm = typify(fasm)
asm = Assignment(f.absolute_access((0,), "data"), x)
fasm = freeze(asm)
fasm = typify(fasm)
# Bad
asm = Assignment(x, f.absolute_access((0,), "size"))
fasm = freeze(asm)
......@@ -87,7 +174,7 @@ def test_contextual_typing():
expr = typify(expr)
def check(expr):
assert expr.dtype == ctx.default_dtype
assert expr.dtype == constify(ctx.default_dtype)
match expr:
case PsConstantExpr(cs):
assert cs.value in (2, 3, -4)
......@@ -199,6 +286,6 @@ def test_typify_constant_clones():
expr_clone = expr.clone()
expr = typify(expr)
assert expr_clone.operand1.dtype is None
assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment