Skip to content
Snippets Groups Projects
Commit b305ce9a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Change symbol typificaiton: Infer unknown LHS data types from RHS

parent db6a4c59
No related branches found
No related tags found
1 merge request!407Change symbol typification: Infer unknown LHS data types from RHS
Pipeline #67715 passed
...@@ -12,7 +12,7 @@ from ...sympyextensions.typed_sympy import TypedSymbol ...@@ -12,7 +12,7 @@ from ...sympyextensions.typed_sympy import TypedSymbol
from ..symbols import PsSymbol from ..symbols import PsSymbol
from ..arrays import PsLinearizedArray from ..arrays import PsLinearizedArray
from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType, deconstify
from ..constraints import KernelParamsConstraint from ..constraints import KernelParamsConstraint
from ..exceptions import PsInternalCompilerError, KernelConstraintsError from ..exceptions import PsInternalCompilerError, KernelConstraintsError
...@@ -63,8 +63,8 @@ class KernelCreationContext: ...@@ -63,8 +63,8 @@ class KernelCreationContext:
default_dtype: PsNumericType = DEFAULTS.numeric_dtype, default_dtype: PsNumericType = DEFAULTS.numeric_dtype,
index_dtype: PsIntegerType = DEFAULTS.index_dtype, index_dtype: PsIntegerType = DEFAULTS.index_dtype,
): ):
self._default_dtype = default_dtype self._default_dtype = deconstify(default_dtype)
self._index_dtype = index_dtype self._index_dtype = deconstify(index_dtype)
self._symbols: dict[str, PsSymbol] = dict() self._symbols: dict[str, PsSymbol] = dict()
......
...@@ -68,10 +68,18 @@ class TypeContext: ...@@ -68,10 +68,18 @@ class TypeContext:
- A set of restrictions on the target type: - A set of restrictions on the target type:
- `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides - `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides
- Additional restrictions may be added in the future. - Additional restrictions may be added in the future.
The type context also tracks the tree traversal of the typifier:
- ``is_lhs`` is set to True while a left-hand side expression is being processed,
and `False` while a right-hand side expression is processed.
""" """
def __init__( def __init__(
self, target_type: PsType | None = None, require_nonconst: bool = False self,
target_type: PsType | None = None,
require_nonconst: bool = False,
is_lhs: bool = False
): ):
self._require_nonconst = require_nonconst self._require_nonconst = require_nonconst
self._deferred_exprs: list[PsExpression] = [] self._deferred_exprs: list[PsExpression] = []
...@@ -80,6 +88,8 @@ class TypeContext: ...@@ -80,6 +88,8 @@ class TypeContext:
self._fix_constness(target_type) if target_type is not None else None self._fix_constness(target_type) if target_type is not None else None
) )
self._is_lhs = is_lhs
@property @property
def target_type(self) -> PsType | None: def target_type(self) -> PsType | None:
return self._target_type return self._target_type
...@@ -88,6 +98,14 @@ class TypeContext: ...@@ -88,6 +98,14 @@ class TypeContext:
def require_nonconst(self) -> bool: def require_nonconst(self) -> bool:
return self._require_nonconst return self._require_nonconst
@property
def is_lhs(self) -> bool:
return self._is_lhs
@is_lhs.setter
def is_lhs(self, value: bool):
self._is_lhs = value
def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None): def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
"""Applies the given ``dtype`` to this type context, and optionally to the given expression. """Applies the given ``dtype`` to this type context, and optionally to the given expression.
...@@ -171,8 +189,10 @@ class TypeContext: ...@@ -171,8 +189,10 @@ class TypeContext:
) )
case PsSymbolExpr(symb): case PsSymbolExpr(symb):
assert symb.dtype is not None if symb.dtype is None:
if not self._compatible(symb.dtype): # Symbols are not forced to constness
symb.dtype = deconstify(self._target_type)
elif not self._compatible(symb.dtype):
raise TypificationError( raise TypificationError(
f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n" f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n"
f" Symbol type: {symb.dtype}\n" f" Symbol type: {symb.dtype}\n"
...@@ -262,7 +282,11 @@ class Typifier: ...@@ -262,7 +282,11 @@ class Typifier:
The following general rules apply: The following general rules apply:
- The context's `default_dtype` is applied to all untyped symbols - The context's ``default_dtype`` is applied to all untyped symbols encountered inside a right-hand side expression
- If an untyped symbol is encountered on an assignment's left-hand side, it will first be attempted to infer its
type from the right-hand side. If that fails, the context's ``default_dtype`` will be applied.
- It is an error if an untyped symbol occurs in the same type context as a typed symbol or constant
with a non-default data type.
- By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's - By default, all expressions receive a ``const`` type unless they occur on a (non-declaration) assignment's
left-hand side left-hand side
...@@ -280,7 +304,12 @@ class Typifier: ...@@ -280,7 +304,12 @@ class Typifier:
def __call__(self, node: NodeT) -> NodeT: def __call__(self, node: NodeT) -> NodeT:
if isinstance(node, PsExpression): if isinstance(node, PsExpression):
self.visit_expr(node, TypeContext()) tc = TypeContext()
self.visit_expr(node, tc)
if tc.target_type is None:
# no type could be inferred -> take the default
tc.apply_dtype(self._ctx.default_dtype)
else: else:
self.visit(node) self.visit(node)
return node return node
...@@ -305,19 +334,31 @@ class Typifier: ...@@ -305,19 +334,31 @@ class Typifier:
case PsDeclaration(lhs, rhs): case PsDeclaration(lhs, rhs):
tc = TypeContext() tc = TypeContext()
# LHS defines target type; type context carries it to RHS
tc.is_lhs = True
self.visit_expr(lhs, tc) self.visit_expr(lhs, tc)
assert tc.target_type is not None tc.is_lhs = False
self.visit_expr(rhs, tc) self.visit_expr(rhs, tc)
if tc.target_type is None:
# no type has been inferred -> use the default dtype
tc.apply_dtype(self._ctx.default_dtype)
case PsAssignment(lhs, rhs): case PsAssignment(lhs, rhs):
tc_lhs = TypeContext(require_nonconst=True) tc_lhs = TypeContext(require_nonconst=True, is_lhs=True)
self.visit_expr(lhs, tc_lhs) self.visit_expr(lhs, tc_lhs)
assert tc_lhs.target_type is not None
tc_rhs = TypeContext(tc_lhs.target_type, require_nonconst=False) tc_rhs = TypeContext(target_type=tc_lhs.target_type)
self.visit_expr(rhs, tc_rhs) self.visit_expr(rhs, tc_rhs)
if tc_rhs.target_type is None:
tc_rhs.apply_dtype(self._ctx.default_dtype)
if tc_lhs.target_type is None:
assert tc_rhs.target_type is not None
tc_lhs.apply_dtype(deconstify(tc_rhs.target_type))
case PsConditional(cond, branch_true, branch_false): case PsConditional(cond, branch_true, branch_false):
cond_tc = TypeContext(PsBoolType()) cond_tc = TypeContext(PsBoolType())
self.visit_expr(cond, cond_tc) self.visit_expr(cond, cond_tc)
...@@ -330,6 +371,7 @@ class Typifier: ...@@ -330,6 +371,7 @@ class Typifier:
case PsLoop(ctr, start, stop, step, body): case PsLoop(ctr, start, stop, step, body):
if ctr.symbol.dtype is None: if ctr.symbol.dtype is None:
ctr.symbol.apply_dtype(self._ctx.index_dtype) ctr.symbol.apply_dtype(self._ctx.index_dtype)
ctr.dtype = ctr.symbol.get_dtype()
tc_index = TypeContext(ctr.symbol.dtype) tc_index = TypeContext(ctr.symbol.dtype)
self.visit_expr(start, tc_index) self.visit_expr(start, tc_index)
...@@ -355,11 +397,16 @@ class Typifier: ...@@ -355,11 +397,16 @@ class Typifier:
either ``apply_dtype`` or ``infer_dtype``. either ``apply_dtype`` or ``infer_dtype``.
""" """
match expr: match expr:
case PsSymbolExpr(_): case PsSymbolExpr(symb):
if expr.symbol.dtype is None: if tc.is_lhs:
expr.symbol.dtype = self._ctx.default_dtype if symb.dtype is not None:
tc.apply_dtype(symb.dtype, expr)
tc.apply_dtype(expr.symbol.dtype, expr) elif tc.is_lhs:
tc.infer_dtype(expr)
else:
if symb.dtype is None:
symb.dtype = self._ctx.default_dtype
tc.apply_dtype(symb.dtype, expr)
case PsConstantExpr(c): case PsConstantExpr(c):
if c.dtype is not None: if c.dtype is not None:
......
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
from typing import cast from typing import cast
from pystencils import Assignment, TypedSymbol, Field, FieldType from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment
from pystencils.backend.ast.structural import ( from pystencils.backend.ast.structural import (
PsDeclaration, PsDeclaration,
...@@ -186,7 +186,7 @@ def test_typify_structs(): ...@@ -186,7 +186,7 @@ def test_typify_structs():
fasm = typify(fasm) fasm = typify(fasm)
def test_contextual_typing(): def test_default_typing():
ctx = KernelCreationContext() ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx) freeze = FreezeExpressions(ctx)
typify = Typifier(ctx) typify = Typifier(ctx)
...@@ -213,6 +213,43 @@ def test_contextual_typing(): ...@@ -213,6 +213,43 @@ def test_contextual_typing():
check(expr) check(expr)
def test_lhs_inference():
ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y, z = sp.symbols("x, y, z")
q = TypedSymbol("q", np.float32)
w = TypedSymbol("w", np.float16)
# Type of the LHS is propagated to untyped RHS symbols
asm = Assignment(x, 3 - q)
fasm = typify(freeze(asm))
assert ctx.get_symbol("x").dtype == Fp(32)
assert fasm.lhs.dtype == constify(Fp(32))
asm = Assignment(y, 3 - w)
fasm = typify(freeze(asm))
assert ctx.get_symbol("y").dtype == Fp(16)
assert fasm.lhs.dtype == constify(Fp(16))
fasm = PsAssignment(PsExpression.make(ctx.get_symbol("z")), freeze(3 - w))
fasm = typify(fasm)
assert ctx.get_symbol("z").dtype == Fp(16)
assert fasm.lhs.dtype == Fp(16)
fasm = PsDeclaration(PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q)))
fasm = typify(fasm)
assert ctx.get_symbol("r").dtype == Bool()
assert fasm.lhs.dtype == constify(Bool())
assert fasm.rhs.dtype == constify(Bool())
def test_erronous_typing(): def test_erronous_typing():
ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
freeze = FreezeExpressions(ctx) freeze = FreezeExpressions(ctx)
...@@ -227,16 +264,23 @@ def test_erronous_typing(): ...@@ -227,16 +264,23 @@ def test_erronous_typing():
with pytest.raises(TypificationError): with pytest.raises(TypificationError):
typify(expr) typify(expr)
# Conflict between LHS and RHS symbols
asm = Assignment(q, 3 - w) asm = Assignment(q, 3 - w)
fasm = freeze(asm) fasm = freeze(asm)
with pytest.raises(TypificationError): with pytest.raises(TypificationError):
typify(fasm) typify(fasm)
# Do not propagate types back from LHS symbols to RHS symbols
asm = Assignment(q, 3 - x) asm = Assignment(q, 3 - x)
fasm = freeze(asm) fasm = freeze(asm)
with pytest.raises(TypificationError): with pytest.raises(TypificationError):
typify(fasm) typify(fasm)
asm = AddAugmentedAssignment(z, 3 - q)
fasm = freeze(asm)
with pytest.raises(TypificationError):
typify(fasm)
def test_typify_integer_binops(): def test_typify_integer_binops():
ctx = KernelCreationContext() ctx = KernelCreationContext()
......
...@@ -51,9 +51,6 @@ def add_fixed_constant_boundary_handling(assignments, with_cse): ...@@ -51,9 +51,6 @@ def add_fixed_constant_boundary_handling(assignments, with_cse):
@pytest.mark.parametrize('dtype', ('float64', 'float32')) @pytest.mark.parametrize('dtype', ('float64', 'float32'))
@pytest.mark.parametrize('with_cse', (False, 'with_cse')) @pytest.mark.parametrize('with_cse', (False, 'with_cse'))
def test_boundary_check(dtype, with_cse): def test_boundary_check(dtype, with_cse):
if with_cse:
pytest.xfail("Doesn't typify correctly yet.")
f, g = ps.fields(f"f, g : {dtype}[2D]") f, g = ps.fields(f"f, g : {dtype}[2D]")
stencil = ps.Assignment(g[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4) stencil = ps.Assignment(g[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment