Skip to content
Snippets Groups Projects
Commit f1486c07 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Typing fixes:

 - Symbols now also receive the const-qualification of `default_dtype`
 - Fix const propagation to structs
 - Add doc comment explaining all this
parent afdd0127
No related branches found
No related tags found
2 merge requests!373Symbol Canonicalization, Loop-Invariant Code Motion, and AST Factory,!372Fix handling of constness in Typifier
...@@ -65,13 +65,24 @@ class TypeContext: ...@@ -65,13 +65,24 @@ class TypeContext:
self._fix_constness(target_type) if target_type is not None else None self._fix_constness(target_type) if target_type is not None else None
) )
def apply_dtype(self, expr: PsExpression | None, dtype: PsType): @property
"""Applies the given ``dtype`` to the given expression inside this type context. def target_type(self) -> PsType | None:
return self._target_type
@property
def require_nonconst(self) -> bool:
return self._require_nonconst
def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
"""Applies the given ``dtype`` to this type context, and optionally to the given expression.
The given expression will be covered by this type context.
If the context's target_type is already known, it must be compatible with the given dtype. If the context's target_type is already known, it must be compatible with the given dtype.
If the target type is still unknown, target_type is set to dtype and retroactively applied If the target type is still unknown, target_type is set to dtype and retroactively applied
to all deferred expressions. to all deferred expressions.
If an expression is specified, it will be covered by the type context.
If the expression already has a data type set, it must be compatible with the target type
and will be replaced by it.
""" """
dtype = self._fix_constness(dtype) dtype = self._fix_constness(dtype)
...@@ -96,7 +107,8 @@ class TypeContext: ...@@ -96,7 +107,8 @@ class TypeContext:
Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is Otherwise, the expression is deferred, and a type will be applied to it as soon as `apply_type` is
called on this context. called on this context.
If the expression already has a data type set, it must be equal to the inferred type. If the expression already has a data type set, it must be compatible with the target type
and will be replaced by it.
""" """
if self._target_type is None: if self._target_type is None:
...@@ -136,9 +148,8 @@ class TypeContext: ...@@ -136,9 +148,8 @@ class TypeContext:
) )
case PsSymbolExpr(symb): case PsSymbolExpr(symb):
if symb.dtype is None: assert symb.dtype is not None
symb.dtype = self._target_type if not self._compatible(symb.dtype):
elif not self._compatible(symb.dtype):
raise TypificationError( raise TypificationError(
f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n" f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n"
f" Symbol type: {symb.dtype}\n" f" Symbol type: {symb.dtype}\n"
...@@ -183,14 +194,12 @@ class TypeContext: ...@@ -183,14 +194,12 @@ class TypeContext:
else: else:
return constify(dtype) return constify(dtype)
@property
def target_type(self) -> PsType | None:
return self._target_type
class Typifier: class Typifier:
"""Apply data types to expressions. """Apply data types to expressions.
**Contextual Typing**
The Typifier will traverse the AST and apply a contextual typing scheme to figure out The Typifier will traverse the AST and apply a contextual typing scheme to figure out
the data types of all encountered expressions. the data types of all encountered expressions.
To this end, it covers each expression tree with a set of disjoint typing contexts. To this end, it covers each expression tree with a set of disjoint typing contexts.
...@@ -211,6 +220,21 @@ class Typifier: ...@@ -211,6 +220,21 @@ class Typifier:
in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon in the context, the expression is deferred by storing it in the context, and will be assigned a type as soon
as the target type is fixed. as the target type is fixed.
**Typing Rules**
The following rules apply:
- The context's `default_dtype` is applied to all untyped symbols
- By default, all expressions receive a ``const`` type unless otherwise required
- The left-hand side of any non-declaration assignment must not be ``const``
**Typing of symbol expressions**
Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but
not necessarily their const-qualification.
A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type,
and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`,
but not vice versa.
""" """
def __init__(self, ctx: KernelCreationContext): def __init__(self, ctx: KernelCreationContext):
...@@ -284,20 +308,23 @@ class Typifier: ...@@ -284,20 +308,23 @@ class Typifier:
match expr: match expr:
case PsSymbolExpr(_): case PsSymbolExpr(_):
if expr.symbol.dtype is None: if expr.symbol.dtype is None:
tc.apply_dtype(expr, self._ctx.default_dtype) expr.symbol.dtype = self._ctx.default_dtype
else:
tc.apply_dtype(expr, expr.symbol.dtype)
case PsConstantExpr(_): tc.apply_dtype(expr.symbol.dtype, expr)
tc.infer_dtype(expr)
case PsConstantExpr(c):
if c.dtype is not None:
tc.apply_dtype(c.dtype, expr)
else:
tc.infer_dtype(expr)
case PsArrayAccess(bptr, idx): case PsArrayAccess(bptr, idx):
tc.apply_dtype(expr, bptr.array.element_type) tc.apply_dtype(bptr.array.element_type, expr)
index_tc = TypeContext() index_tc = TypeContext()
self.visit_expr(idx, index_tc) self.visit_expr(idx, index_tc)
if index_tc.target_type is None: if index_tc.target_type is None:
index_tc.apply_dtype(idx, self._ctx.index_dtype) index_tc.apply_dtype(self._ctx.index_dtype, idx)
elif not isinstance(index_tc.target_type, PsIntegerType): elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError( raise TypificationError(
f"Array index is not of integer type: {idx} has type {index_tc.target_type}" f"Array index is not of integer type: {idx} has type {index_tc.target_type}"
...@@ -312,12 +339,12 @@ class Typifier: ...@@ -312,12 +339,12 @@ class Typifier:
"Type of subscript base is not subscriptable." "Type of subscript base is not subscriptable."
) )
tc.apply_dtype(expr, arr_tc.target_type.base_type) tc.apply_dtype(arr_tc.target_type.base_type, expr)
index_tc = TypeContext() index_tc = TypeContext()
self.visit_expr(idx, index_tc) self.visit_expr(idx, index_tc)
if index_tc.target_type is None: if index_tc.target_type is None:
index_tc.apply_dtype(idx, self._ctx.index_dtype) index_tc.apply_dtype(self._ctx.index_dtype, idx)
elif not isinstance(index_tc.target_type, PsIntegerType): elif not isinstance(index_tc.target_type, PsIntegerType):
raise TypificationError( raise TypificationError(
f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}" f"Subscript index is not of integer type: {idx} has type {index_tc.target_type}"
...@@ -332,7 +359,7 @@ class Typifier: ...@@ -332,7 +359,7 @@ class Typifier:
"Type of argument to a Deref is not dereferencable" "Type of argument to a Deref is not dereferencable"
) )
tc.apply_dtype(expr, ptr_tc.target_type.base_type) tc.apply_dtype(ptr_tc.target_type.base_type, expr)
case PsAddressOf(arg): case PsAddressOf(arg):
arg_tc = TypeContext() arg_tc = TypeContext()
...@@ -344,10 +371,11 @@ class Typifier: ...@@ -344,10 +371,11 @@ class Typifier:
) )
ptr_type = PsPointerType(arg_tc.target_type, True) ptr_type = PsPointerType(arg_tc.target_type, True)
tc.apply_dtype(expr, ptr_type) tc.apply_dtype(ptr_type, expr)
case PsLookup(aggr, member_name): case PsLookup(aggr, member_name):
aggr_tc = TypeContext(None) # Members of a struct type inherit the struct type's `const` qualifier
aggr_tc = TypeContext(None, require_nonconst=tc.require_nonconst)
self.visit_expr(aggr, aggr_tc) self.visit_expr(aggr, aggr_tc)
aggr_type = aggr_tc.target_type aggr_type = aggr_tc.target_type
...@@ -361,12 +389,12 @@ class Typifier: ...@@ -361,12 +389,12 @@ class Typifier:
raise TypificationError( raise TypificationError(
f"Aggregate of type {aggr_type} does not have a member {member}." f"Aggregate of type {aggr_type} does not have a member {member}."
) )
member_type = member.dtype member_type = member.dtype
if aggr_type.const: if aggr_type.const:
member_type = constify(member_type) member_type = constify(member_type)
tc.apply_dtype(expr, member_type) tc.apply_dtype(member_type, expr)
case PsBinOp(op1, op2): case PsBinOp(op1, op2):
self.visit_expr(op1, tc) self.visit_expr(op1, tc)
...@@ -405,14 +433,14 @@ class Typifier: ...@@ -405,14 +433,14 @@ class Typifier:
f"{len(items)} items as {tc.target_type}" f"{len(items)} items as {tc.target_type}"
) )
else: else:
items_tc.apply_dtype(None, tc.target_type.base_type) items_tc.apply_dtype(tc.target_type.base_type)
else: else:
arr_type = PsArrayType(items_tc.target_type, len(items)) arr_type = PsArrayType(items_tc.target_type, len(items))
tc.apply_dtype(expr, arr_type) tc.apply_dtype(arr_type, expr)
case PsCast(dtype, arg): case PsCast(dtype, arg):
self.visit_expr(arg, TypeContext()) self.visit_expr(arg, TypeContext())
tc.apply_dtype(expr, dtype) tc.apply_dtype(dtype, expr)
case _: case _:
raise NotImplementedError(f"Can't typify {expr}") raise NotImplementedError(f"Can't typify {expr}")
...@@ -45,7 +45,7 @@ def test_typify_simple(): ...@@ -45,7 +45,7 @@ def test_typify_simple():
assert cs.dtype == constify(ctx.default_dtype) assert cs.dtype == constify(ctx.default_dtype)
case PsSymbolExpr(symb): case PsSymbolExpr(symb):
assert symb.name in "xyz" assert symb.name in "xyz"
assert symb.dtype == constify(ctx.default_dtype) assert symb.dtype == ctx.default_dtype
case PsBinOp(op1, op2): case PsBinOp(op1, op2):
check(op1) check(op1)
check(op2) check(op2)
...@@ -109,7 +109,6 @@ def test_lhs_constness(): ...@@ -109,7 +109,6 @@ def test_lhs_constness():
) )
x, y, z = sp.symbols("x, y, z") x, y, z = sp.symbols("x, y, z")
q = TypedSymbol("q", Fp(32, const=True))
# Assignment RHS may not be const # Assignment RHS may not be const
asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y))) asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y)))
...@@ -119,9 +118,6 @@ def test_lhs_constness(): ...@@ -119,9 +118,6 @@ def test_lhs_constness():
with pytest.raises(TypificationError): with pytest.raises(TypificationError):
_ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y))) _ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y)))
# with pytest.raises(TypificationError):
# _ = typify(freeze(Assignment(q, x - z)))
np_struct = np.dtype([("size", np.uint32), ("data", np.float32)]) np_struct = np.dtype([("size", np.uint32), ("data", np.float32)])
struct_type = constify(create_type(np_struct)) struct_type = constify(create_type(np_struct))
struct_field = Field.create_generic( struct_field = Field.create_generic(
...@@ -146,6 +142,10 @@ def test_typify_structs(): ...@@ -146,6 +142,10 @@ def test_typify_structs():
fasm = freeze(asm) fasm = freeze(asm)
fasm = typify(fasm) fasm = typify(fasm)
asm = Assignment(f.absolute_access((0,), "data"), x)
fasm = freeze(asm)
fasm = typify(fasm)
# Bad # Bad
asm = Assignment(x, f.absolute_access((0,), "size")) asm = Assignment(x, f.absolute_access((0,), "size"))
fasm = freeze(asm) fasm = freeze(asm)
...@@ -170,7 +170,7 @@ def test_contextual_typing(): ...@@ -170,7 +170,7 @@ def test_contextual_typing():
assert cs.dtype == constify(ctx.default_dtype) assert cs.dtype == constify(ctx.default_dtype)
case PsSymbolExpr(symb): case PsSymbolExpr(symb):
assert symb.name in "xyz" assert symb.name in "xyz"
assert symb.dtype == constify(ctx.default_dtype) assert symb.dtype == ctx.default_dtype
case PsBinOp(op1, op2): case PsBinOp(op1, op2):
check(op1) check(op1)
check(op2) check(op2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment