Skip to content
Snippets Groups Projects
Commit fcbdf0f5 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Be more conservative about what is considered an LHS symbol

parent b305ce9a
No related branches found
No related tags found
1 merge request!407Change symbol typification: Infer unknown LHS data types from RHS
Pipeline #67716 passed
......@@ -68,18 +68,12 @@ class TypeContext:
- A set of restrictions on the target type:
- `require_nonconst` to make sure the target type is not `const`, as required on assignment left-hand sides
- Additional restrictions may be added in the future.
The type context also tracks the tree traversal of the typifier:
- ``is_lhs`` is set to True while a left-hand side expression is being processed,
and `False` while a right-hand side expression is processed.
"""
def __init__(
self,
target_type: PsType | None = None,
require_nonconst: bool = False,
is_lhs: bool = False
):
self._require_nonconst = require_nonconst
self._deferred_exprs: list[PsExpression] = []
......@@ -88,8 +82,6 @@ class TypeContext:
self._fix_constness(target_type) if target_type is not None else None
)
self._is_lhs = is_lhs
@property
def target_type(self) -> PsType | None:
return self._target_type
......@@ -97,14 +89,6 @@ class TypeContext:
@property
def require_nonconst(self) -> bool:
return self._require_nonconst
@property
def is_lhs(self) -> bool:
return self._is_lhs
@is_lhs.setter
def is_lhs(self, value: bool):
self._is_lhs = value
def apply_dtype(self, dtype: PsType, expr: PsExpression | None = None):
"""Applies the given ``dtype`` to this type context, and optionally to the given expression.
......@@ -333,29 +317,41 @@ class Typifier:
self.visit(s)
case PsDeclaration(lhs, rhs):
# Only if the LHS is an untyped symbol, infer its type from the RHS
infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None
tc = TypeContext()
tc.is_lhs = True
self.visit_expr(lhs, tc)
tc.is_lhs = False
if infer_lhs:
tc.infer_dtype(lhs)
else:
self.visit_expr(lhs, tc)
assert tc.target_type is not None
self.visit_expr(rhs, tc)
if tc.target_type is None:
if infer_lhs and tc.target_type is None:
# no type has been inferred -> use the default dtype
tc.apply_dtype(self._ctx.default_dtype)
case PsAssignment(lhs, rhs):
tc_lhs = TypeContext(require_nonconst=True, is_lhs=True)
self.visit_expr(lhs, tc_lhs)
infer_lhs = isinstance(lhs, PsSymbolExpr) and lhs.symbol.dtype is None
tc_lhs = TypeContext(require_nonconst=True)
if infer_lhs:
tc_lhs.infer_dtype(lhs)
else:
self.visit_expr(lhs, tc_lhs)
assert tc_lhs.target_type is not None
tc_rhs = TypeContext(target_type=tc_lhs.target_type)
self.visit_expr(rhs, tc_rhs)
if tc_rhs.target_type is None:
tc_rhs.apply_dtype(self._ctx.default_dtype)
if tc_lhs.target_type is None:
if infer_lhs:
if tc_rhs.target_type is None:
tc_rhs.apply_dtype(self._ctx.default_dtype)
assert tc_rhs.target_type is not None
tc_lhs.apply_dtype(deconstify(tc_rhs.target_type))
......@@ -398,15 +394,9 @@ class Typifier:
"""
match expr:
case PsSymbolExpr(symb):
if tc.is_lhs:
if symb.dtype is not None:
tc.apply_dtype(symb.dtype, expr)
elif tc.is_lhs:
tc.infer_dtype(expr)
else:
if symb.dtype is None:
symb.dtype = self._ctx.default_dtype
tc.apply_dtype(symb.dtype, expr)
if symb.dtype is None:
symb.dtype = self._ctx.default_dtype
tc.apply_dtype(symb.dtype, expr)
case PsConstantExpr(c):
if c.dtype is not None:
......
......@@ -16,6 +16,7 @@ from pystencils.backend.ast.structural import (
from pystencils.backend.ast.expressions import (
PsConstantExpr,
PsSymbolExpr,
PsSubscript,
PsBinOp,
PsAnd,
PsOr,
......@@ -27,12 +28,12 @@ from pystencils.backend.ast.expressions import (
PsGt,
PsLt,
PsCall,
PsTernary
PsTernary,
)
from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import CFunction
from pystencils.types import constify, create_type, create_numeric_type
from pystencils.types.quick import Fp, Int, Bool
from pystencils.types.quick import Fp, Int, Bool, Arr
from pystencils.backend.kernelcreation.context import KernelCreationContext
from pystencils.backend.kernelcreation.freeze import FreezeExpressions
from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
......@@ -242,9 +243,11 @@ def test_lhs_inference():
assert ctx.get_symbol("z").dtype == Fp(16)
assert fasm.lhs.dtype == Fp(16)
fasm = PsDeclaration(PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q)))
fasm = PsDeclaration(
PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q))
)
fasm = typify(fasm)
assert ctx.get_symbol("r").dtype == Bool()
assert fasm.lhs.dtype == constify(Bool())
assert fasm.rhs.dtype == constify(Bool())
......@@ -282,6 +285,26 @@ def test_erronous_typing():
typify(fasm)
def test_invalid_indices():
ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
typify = Typifier(ctx)
arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64))))
x, y, z = [PsExpression.make(ctx.get_symbol(x)) for x in "xyz"]
# Using default-typed symbols as array indices is illegal when the default type is a float
fasm = PsAssignment(PsSubscript(arr, x + y), z)
with pytest.raises(TypificationError):
typify(fasm)
fasm = PsAssignment(z, PsSubscript(arr, x + y))
with pytest.raises(TypificationError):
typify(fasm)
def test_typify_integer_binops():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
......@@ -410,7 +433,7 @@ def test_invalid_conditions():
with pytest.raises(TypificationError):
typify(cond)
def test_typify_ternary():
ctx = KernelCreationContext()
typify = Typifier(ctx)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment