Skip to content
Snippets Groups Projects
Commit 191cc207 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'fhennig/fix-const-typing' into 'backend-rework'

Fix handling of constness in Typifier

See merge request pycodegen/pystencils!372
parents 6c649d20 7f5ffb5b
No related branches found
No related tags found
No related merge requests found
...@@ -59,9 +59,14 @@ class PsConstant: ...@@ -59,9 +59,14 @@ class PsConstant:
@property @property
def dtype(self) -> PsNumericType | None: def dtype(self) -> PsNumericType | None:
"""This constant's data type, or ``None`` if it is untyped.
The data type of a constant always has ``const == True``.
"""
return self._dtype return self._dtype
def get_dtype(self) -> PsNumericType: def get_dtype(self) -> PsNumericType:
"""Retrieve this constant's data type, throwing an exception if the constant is untyped."""
if self._dtype is None: if self._dtype is None:
raise PsInternalCompilerError("Data type of constant was not set.") raise PsInternalCompilerError("Data type of constant was not set.")
return self._dtype return self._dtype
......
...@@ -12,7 +12,7 @@ from ...types import ( ...@@ -12,7 +12,7 @@ from ...types import (
PsDereferencableType, PsDereferencableType,
PsPointerType, PsPointerType,
PsBoolType, PsBoolType,
deconstify, constify,
) )
from ..ast.structural import ( from ..ast.structural import (
PsAstNode, PsAstNode,
...@@ -21,6 +21,7 @@ from ..ast.structural import ( ...@@ -21,6 +21,7 @@ from ..ast.structural import (
PsConditional, PsConditional,
PsExpression, PsExpression,
PsAssignment, PsAssignment,
PsDeclaration,
) )
from ..ast.expressions import ( from ..ast.expressions import (
PsArrayAccess, PsArrayAccess,
...@@ -54,20 +55,48 @@ NodeT = TypeVar("NodeT", bound=PsAstNode) ...@@ -54,20 +55,48 @@ NodeT = TypeVar("NodeT", bound=PsAstNode)
class TypeContext: class TypeContext:
def __init__(self, target_type: PsType | None = None): """Typing context, with support for type inference and checking.
self._target_type = deconstify(target_type) if target_type is not None else None
Instances of this class are used to propagate and check data types across expression subtrees
of the AST. Each type context has:
- A target type `target_type`, which shall be applied to all expressions it covers
- A set of restrictions on the target type:
- `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides
- Additional restrictions may be added in the future.
"""
def __init__(
self, target_type: PsType | None = None, require_nonconst: bool = False
):
self._require_nonconst = require_nonconst
self._deferred_exprs: list[PsExpression] = [] self._deferred_exprs: list[PsExpression] = []
def apply_dtype(self, expr: PsExpression | None, dtype: PsType): self._target_type = (
"""Applies the given ``dtype`` to the given expression inside this type context. self._fix_constness(target_type) if target_type is not None else None
)
@property
def target_type(self) -> PsType | None:
return self._target_type
@property
def require_nonconst(self) -> bool:
return self._require_nonconst
def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
"""Applies the given ``dtype`` to this type context, and optionally to the given expression.
The given expression will be covered by this type context.
If the context's target_type is already known, it must be compatible with the given dtype. If the context's target_type is already known, it must be compatible with the given dtype.
If the target type is still unknown, target_type is set to dtype and retroactively applied If the target type is still unknown, target_type is set to dtype and retroactively applied
to all deferred expressions. to all deferred expressions.
If an expression is specified, it will be covered by the type context.
If the expression already has a data type set, it must be compatible with the target type
and will be replaced by it.
""" """
dtype = deconstify(dtype) dtype = self._fix_constness(dtype)
if self._target_type is not None and dtype != self._target_type: if self._target_type is not None and dtype != self._target_type:
raise TypificationError( raise TypificationError(
...@@ -80,14 +109,7 @@ class TypeContext: ...@@ -80,14 +109,7 @@ class TypeContext:
self._propagate_target_type() self._propagate_target_type()
if expr is not None: if expr is not None:
if expr.dtype is None: self._apply_target_type(expr)
self._apply_target_type(expr)
elif deconstify(expr.dtype) != self._target_type:
raise TypificationError(
"Type conflict: Predefined expression type did not match the context's target type\n"
f" Expression type: {dtype}\n"
f" Target type: {self._target_type}"
)
def infer_dtype(self, expr: PsExpression): def infer_dtype(self, expr: PsExpression):
"""Infer the data type for the given expression. """Infer the data type for the given expression.
...@@ -96,7 +118,8 @@ class TypeContext: ...@@ -96,7 +118,8 @@ class TypeContext:
Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is
called on this context. called on this context.
If the expression already has a data type set, it must be equal to the inferred type. If the expression already has a data type set, it must be compatible with the target type
and will be replaced by it.
""" """
if self._target_type is None: if self._target_type is None:
...@@ -113,7 +136,7 @@ class TypeContext: ...@@ -113,7 +136,7 @@ class TypeContext:
assert self._target_type is not None assert self._target_type is not None
if expr.dtype is not None: if expr.dtype is not None:
if deconstify(expr.dtype) != self.target_type: if not self._compatible(expr.dtype):
raise TypificationError( raise TypificationError(
f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n" f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n"
f" Expression type: {expr.dtype}\n" f" Expression type: {expr.dtype}\n"
...@@ -128,7 +151,7 @@ class TypeContext: ...@@ -128,7 +151,7 @@ class TypeContext:
) )
if c.dtype is None: if c.dtype is None:
expr.constant = c.interpret_as(self._target_type) expr.constant = c.interpret_as(self._target_type)
elif deconstify(c.dtype) != self._target_type: elif not self._compatible(c.dtype):
raise TypificationError( raise TypificationError(
f"Type mismatch at constant {c}: Constant type did not match the context's target type\n" f"Type mismatch at constant {c}: Constant type did not match the context's target type\n"
f" Constant type: {c.dtype}\n" f" Constant type: {c.dtype}\n"
...@@ -136,7 +159,13 @@ class TypeContext: ...@@ -136,7 +159,13 @@ class TypeContext:
) )
case PsSymbolExpr(symb): case PsSymbolExpr(symb):
symb.apply_dtype(self._target_type) assert symb.dtype is not None
if not self._compatible(symb.dtype):
raise TypificationError(
f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n"
f" Symbol type: {symb.dtype}\n"
f" Target type: {self._target_type}"
)
case ( case (
PsIntDiv() PsIntDiv()
...@@ -151,18 +180,42 @@ class TypeContext: ...@@ -151,18 +180,42 @@ class TypeContext:
f" Expression: {expr}" f" Expression: {expr}"
f" Type Context: {self._target_type}" f" Type Context: {self._target_type}"
) )
expr.dtype = self._target_type
# endif # endif
expr.dtype = self._target_type
@property def _compatible(self, dtype: PsType):
def target_type(self) -> PsType | None: """Checks whether the given data type is compatible with the context's target type.
return self._target_type
If the target type is ``const``, they must be equal up to const qualification;
if the target type is not ``const``, `dtype` must match it exactly.
"""
assert self._target_type is not None
if self._target_type.const:
return constify(dtype) == self._target_type
else:
return dtype == self._target_type
def _fix_constness(self, dtype: PsType, expr: PsExpression | None = None):
if self._require_nonconst:
if dtype.const:
if expr is None:
raise TypificationError(
f"Type mismatch: Encountered {dtype} in non-constant context."
)
else:
raise TypificationError(
f"Type mismatch at expression {expr}: Encountered {dtype} in non-constant context."
)
return dtype
else:
return constify(dtype)
class Typifier: class Typifier:
"""Apply data types to expressions. """Apply data types to expressions.
**Contextual Typing**
The Typifier will traverse the AST and apply a contextual typing scheme to figure out The Typifier will traverse the AST and apply a contextual typing scheme to figure out
the data types of all encountered expressions. the data types of all encountered expressions.
To this end, it covers each expression tree with a set of disjoint typing contexts. To this end, it covers each expression tree with a set of disjoint typing contexts.
...@@ -183,6 +236,21 @@ class Typifier: ...@@ -183,6 +236,21 @@ class Typifier:
in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon
as the target type is fixed. as the target type is fixed.
**Typing Rules**
The following general rules apply:
- The context's `default_dtype` is applied to all untyped symbols
- By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's
left-hand side
**Typing of symbol expressions**
Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but
not necessarily their const-qualification.
A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type,
and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`,
but not vice versa.
""" """
def __init__(self, ctx: KernelCreationContext): def __init__(self, ctx: KernelCreationContext):
...@@ -213,13 +281,21 @@ class Typifier: ...@@ -213,13 +281,21 @@ class Typifier:
for s in statements: for s in statements:
self.visit(s) self.visit(s)
case PsAssignment(lhs, rhs): case PsDeclaration(lhs, rhs):
tc = TypeContext() tc = TypeContext()
# LHS defines target type; type context carries it to RHS # LHS defines target type; type context carries it to RHS
self.visit_expr(lhs, tc) self.visit_expr(lhs, tc)
assert tc.target_type is not None assert tc.target_type is not None
self.visit_expr(rhs, tc) self.visit_expr(rhs, tc)
case PsAssignment(lhs, rhs):
tc_lhs = TypeContext(require_nonconst=True)
self.visit_expr(lhs, tc_lhs)
assert tc_lhs.target_type is not None
tc_rhs = TypeContext(tc_lhs.target_type, require_nonconst=False)
self.visit_expr(rhs, tc_rhs)
case PsConditional(cond, branch_true, branch_false): case PsConditional(cond, branch_true, branch_false):
cond_tc = TypeContext(PsBoolType(const=True)) cond_tc = TypeContext(PsBoolType(const=True))
self.visit_expr(cond, cond_tc) self.visit_expr(cond, cond_tc)
...@@ -233,10 +309,10 @@ class Typifier: ...@@ -233,10 +309,10 @@ class Typifier:
if ctr.symbol.dtype is None: if ctr.symbol.dtype is None:
ctr.symbol.apply_dtype(self._ctx.index_dtype) ctr.symbol.apply_dtype(self._ctx.index_dtype)
tc = TypeContext(ctr.symbol.dtype) tc_index = TypeContext(ctr.symbol.dtype)
self.visit_expr(start, tc) self.visit_expr(start, tc_index)
self.visit_expr(stop, tc) self.visit_expr(stop, tc_index)
self.visit_expr(step, tc) self.visit_expr(step, tc_index)
self.visit(body) self.visit(body)
...@@ -244,24 +320,35 @@ class Typifier: ...@@ -244,24 +320,35 @@ class Typifier:
raise NotImplementedError(f"Can't typify {node}") raise NotImplementedError(f"Can't typify {node}")
def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None: def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None:
"""Recursive processing of expression nodes""" """Recursive processing of expression nodes.
This method opens, expands, and closes typing contexts according to the respective expression's
typing rules. It may add or check restrictions only when opening or closing a type context.
The actual type inference and checking during context expansion are performed by the methods
of `TypeContext`. ``visit_expr`` tells the typing context how to handle an expression by calling
either ``apply_dtype`` or ``infer_dtype``.
"""
match expr: match expr:
case PsSymbolExpr(_): case PsSymbolExpr(_):
if expr.dtype is None: if expr.symbol.dtype is None:
tc.apply_dtype(expr, self._ctx.default_dtype) expr.symbol.dtype = self._ctx.default_dtype
else:
tc.apply_dtype(expr, expr.dtype)
case PsConstantExpr(_): tc.apply_dtype(expr.symbol.dtype, expr)
tc.infer_dtype(expr)
case PsConstantExpr(c):
if c.dtype is not None:
tc.apply_dtype(c.dtype, expr)
else:
tc.infer_dtype(expr)
case PsArrayAccess(bptr, idx): case PsArrayAccess(bptr, idx):
tc.apply_dtype(expr, bptr.array.element_type) tc.apply_dtype(bptr.array.element_type, expr)
index_tc = TypeContext() index_tc = TypeContext()
self.visit_expr(idx, index_tc) self.visit_expr(idx, index_tc)
if index_tc.target_type is None: if index_tc.target_type is None:
index_tc.apply_dtype(idx, self._ctx.index_dtype) index_tc.apply_dtype(self._ctx.index_dtype, idx)
elif not isinstance(index_tc.target_type, PsIntegerType): elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError( raise TypificationError(
f"Array index is not of integer type: {idx} has type {index_tc.target_type}" f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
...@@ -276,12 +363,12 @@ class Typifier: ...@@ -276,12 +363,12 @@ class Typifier:
"Type of subscript base is not subscriptable." "Type of subscript base is not subscriptable."
) )
tc.apply_dtype(expr, arr_tc.target_type.base_type) tc.apply_dtype(arr_tc.target_type.base_type, expr)
index_tc = TypeContext() index_tc = TypeContext()
self.visit_expr(idx, index_tc) self.visit_expr(idx, index_tc)
if index_tc.target_type is None: if index_tc.target_type is None:
index_tc.apply_dtype(idx, self._ctx.index_dtype) index_tc.apply_dtype(self._ctx.index_dtype, idx)
elif not isinstance(index_tc.target_type, PsIntegerType): elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError( raise TypificationError(
f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}" f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}"
...@@ -296,7 +383,7 @@ class Typifier: ...@@ -296,7 +383,7 @@ class Typifier:
"Type of argument to a Deref is not dereferencable" "Type of argument to a Deref is not dereferencable"
) )
tc.apply_dtype(expr, ptr_tc.target_type.base_type) tc.apply_dtype(ptr_tc.target_type.base_type, expr)
case PsAddressOf(arg): case PsAddressOf(arg):
arg_tc = TypeContext() arg_tc = TypeContext()
...@@ -308,10 +395,11 @@ class Typifier: ...@@ -308,10 +395,11 @@ class Typifier:
) )
ptr_type = PsPointerType(arg_tc.target_type, True) ptr_type = PsPointerType(arg_tc.target_type, True)
tc.apply_dtype(expr, ptr_type) tc.apply_dtype(ptr_type, expr)
case PsLookup(aggr, member_name): case PsLookup(aggr, member_name):
aggr_tc = TypeContext(None) # Members of a struct type inherit the struct type's `const` qualifier
aggr_tc = TypeContext(None, require_nonconst=tc.require_nonconst)
self.visit_expr(aggr, aggr_tc) self.visit_expr(aggr, aggr_tc)
aggr_type = aggr_tc.target_type aggr_type = aggr_tc.target_type
...@@ -326,7 +414,11 @@ class Typifier: ...@@ -326,7 +414,11 @@ class Typifier:
f"Aggregate of type {aggr_type} does not have a member {member}." f"Aggregate of type {aggr_type} does not have a member {member}."
) )
tc.apply_dtype(expr, member.dtype) member_type = member.dtype
if aggr_type.const:
member_type = constify(member_type)
tc.apply_dtype(member_type, expr)
case PsBinOp(op1, op2): case PsBinOp(op1, op2):
self.visit_expr(op1, tc) self.visit_expr(op1, tc)
...@@ -365,14 +457,14 @@ class Typifier: ...@@ -365,14 +457,14 @@ class Typifier:
f"{len(items)} items as {tc.target_type}" f"{len(items)} items as {tc.target_type}"
) )
else: else:
items_tc.apply_dtype(None, tc.target_type.base_type) items_tc.apply_dtype(tc.target_type.base_type)
else: else:
arr_type = PsArrayType(items_tc.target_type, len(items)) arr_type = PsArrayType(items_tc.target_type, len(items))
tc.apply_dtype(expr, arr_type) tc.apply_dtype(arr_type, expr)
case PsCast(dtype, arg): case PsCast(dtype, arg):
self.visit_expr(arg, TypeContext()) self.visit_expr(arg, TypeContext())
tc.apply_dtype(expr, dtype) tc.apply_dtype(dtype, expr)
case _: case _:
raise NotImplementedError(f"Can't typify {expr}") raise NotImplementedError(f"Can't typify {expr}")
...@@ -6,11 +6,11 @@ from typing import cast ...@@ -6,11 +6,11 @@ from typing import cast
from pystencils import Assignment, TypedSymbol, Field, FieldType from pystencils import Assignment, TypedSymbol, Field, FieldType
from pystencils.backend.ast.structural import PsDeclaration from pystencils.backend.ast.structural import PsDeclaration, PsAssignment, PsExpression
from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp
from pystencils.backend.constants import PsConstant from pystencils.backend.constants import PsConstant
from pystencils.types import constify from pystencils.types import constify
from pystencils.types.quick import Fp, create_numeric_type from pystencils.types.quick import Fp, create_type, create_numeric_type
from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.context import KernelCreationContext
from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.freeze import FreezeExpressions
from pystencils.backend.kernelcreation.typification import Typifier, TypificationError from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
...@@ -38,7 +38,7 @@ def test_typify_simple(): ...@@ -38,7 +38,7 @@ def test_typify_simple():
assert isinstance(fasm, PsDeclaration) assert isinstance(fasm, PsDeclaration)
def check(expr): def check(expr):
assert expr.dtype == ctx.default_dtype assert expr.dtype == constify(ctx.default_dtype)
match expr: match expr:
case PsConstantExpr(cs): case PsConstantExpr(cs):
assert cs.value == 2 assert cs.value == 2
...@@ -56,6 +56,89 @@ def test_typify_simple(): ...@@ -56,6 +56,89 @@ def test_typify_simple():
check(fasm.rhs) check(fasm.rhs)
def test_rhs_constness():
default_type = Fp(32)
ctx = KernelCreationContext(default_dtype=default_type)
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
f = Field.create_generic(
"f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
)
f_const = Field.create_generic(
"f_const",
1,
index_shape=(1,),
dtype=constify(default_type),
field_type=FieldType.CUSTOM,
)
x, y, z = sp.symbols("x, y, z")
# Right-hand sides should always get const types
asm = typify(freeze(Assignment(x, f.absolute_access([0], [0]))))
assert asm.rhs.get_dtype().const
asm = typify(
freeze(
Assignment(
f.absolute_access([0], [0]),
f.absolute_access([0], [0]) * f_const.absolute_access([0], [0]) * x + y,
)
)
)
assert asm.rhs.get_dtype().const
def test_lhs_constness():
default_type = Fp(32)
ctx = KernelCreationContext(default_dtype=default_type)
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
f = Field.create_generic(
"f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
)
f_const = Field.create_generic(
"f_const",
1,
index_shape=(1,),
dtype=constify(default_type),
field_type=FieldType.CUSTOM,
)
x, y, z = sp.symbols("x, y, z")
# Assignment RHS may not be const
asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y)))
assert not asm.lhs.get_dtype().const
# Cannot assign to const left-hand side
with pytest.raises(TypificationError):
_ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y)))
np_struct = np.dtype([("size", np.uint32), ("data", np.float32)])
struct_type = constify(create_type(np_struct))
struct_field = Field.create_generic(
"struct_field", 1, dtype=struct_type, field_type=FieldType.CUSTOM
)
with pytest.raises(TypificationError):
_ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x)))
# Const LHS is only OK in declarations
q = ctx.get_symbol("q", Fp(32, const=True))
ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q))
ast = typify(ast)
assert ast.lhs.dtype == Fp(32, const=True)
ast = PsAssignment(PsExpression.make(q), PsExpression.make(q))
with pytest.raises(TypificationError):
typify(ast)
def test_typify_structs(): def test_typify_structs():
ctx = KernelCreationContext(default_dtype=Fp(32)) ctx = KernelCreationContext(default_dtype=Fp(32))
freeze = FreezeExpressions(ctx) freeze = FreezeExpressions(ctx)
...@@ -70,6 +153,10 @@ def test_typify_structs(): ...@@ -70,6 +153,10 @@ def test_typify_structs():
fasm = freeze(asm) fasm = freeze(asm)
fasm = typify(fasm) fasm = typify(fasm)
asm = Assignment(f.absolute_access((0,), "data"), x)
fasm = freeze(asm)
fasm = typify(fasm)
# Bad # Bad
asm = Assignment(x, f.absolute_access((0,), "size")) asm = Assignment(x, f.absolute_access((0,), "size"))
fasm = freeze(asm) fasm = freeze(asm)
...@@ -87,7 +174,7 @@ def test_contextual_typing(): ...@@ -87,7 +174,7 @@ def test_contextual_typing():
expr = typify(expr) expr = typify(expr)
def check(expr): def check(expr):
assert expr.dtype == ctx.default_dtype assert expr.dtype == constify(ctx.default_dtype)
match expr: match expr:
case PsConstantExpr(cs): case PsConstantExpr(cs):
assert cs.value in (2, 3, -4) assert cs.value in (2, 3, -4)
...@@ -199,6 +286,6 @@ def test_typify_constant_clones(): ...@@ -199,6 +286,6 @@ def test_typify_constant_clones():
expr_clone = expr.clone() expr_clone = expr.clone()
expr = typify(expr) expr = typify(expr)
assert expr_clone.operand1.dtype is None assert expr_clone.operand1.dtype is None
assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment