diff --git a/src/pystencils/backend/constants.py b/src/pystencils/backend/constants.py index 6dc07842ff12a1aadeedd37b15e3ee8c8185ebe9..125c1149ba7fec7ce279afef089229f70c75eb18 100644 --- a/src/pystencils/backend/constants.py +++ b/src/pystencils/backend/constants.py @@ -59,9 +59,14 @@ class PsConstant: @property def dtype(self) -> PsNumericType | None: + """This constant's data type, or ``None`` if it is untyped. + + The data type of a constant always has ``const == True``. + """ return self._dtype def get_dtype(self) -> PsNumericType: + """Retrieve this constant's data type, throwing an exception if the constant is untyped.""" if self._dtype is None: raise PsInternalCompilerError("Data type of constant was not set.") return self._dtype diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index d2c93e22109c4970ad5bb8915e7c6c8cc84173d6..0ea0e9d512a09e95f441043ea9e273340be7ccba 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -12,7 +12,7 @@ from ...types import ( PsDereferencableType, PsPointerType, PsBoolType, - deconstify, + constify, ) from ..ast.structural import ( PsAstNode, @@ -21,6 +21,7 @@ from ..ast.structural import ( PsConditional, PsExpression, PsAssignment, + PsDeclaration, ) from ..ast.expressions import ( PsArrayAccess, @@ -54,10 +55,16 @@ NodeT = TypeVar("NodeT", bound=PsAstNode) class TypeContext: - def __init__(self, target_type: PsType | None = None): - self._target_type = deconstify(target_type) if target_type is not None else None + def __init__( + self, target_type: PsType | None = None, require_nonconst: bool = False + ): + self._require_nonconst = require_nonconst self._deferred_exprs: list[PsExpression] = [] + self._target_type = ( + self._fix_constness(target_type) if target_type is not None else None + ) + def apply_dtype(self, expr: PsExpression | None, dtype: PsType): """Applies the given ``dtype`` to the given expression inside this type context. @@ -67,7 +74,7 @@ class TypeContext: to all deferred expressions. """ - dtype = deconstify(dtype) + dtype = self._fix_constness(dtype) if self._target_type is not None and dtype != self._target_type: raise TypificationError( @@ -80,14 +87,7 @@ class TypeContext: self._propagate_target_type() if expr is not None: - if expr.dtype is None: - self._apply_target_type(expr) - elif deconstify(expr.dtype) != self._target_type: - raise TypificationError( - "Type conflict: Predefined expression type did not match the context's target type\n" - f" Expression type: {dtype}\n" - f" Target type: {self._target_type}" - ) + self._apply_target_type(expr) def infer_dtype(self, expr: PsExpression): """Infer the data type for the given expression. @@ -113,7 +113,7 @@ class TypeContext: assert self._target_type is not None if expr.dtype is not None: - if deconstify(expr.dtype) != self.target_type: + if not self._compatible(expr.dtype): raise TypificationError( f"Type mismatch at expression {expr}: Expression type did not match the context's target type\n" f" Expression type: {expr.dtype}\n" @@ -128,7 +128,7 @@ class TypeContext: ) if c.dtype is None: expr.constant = c.interpret_as(self._target_type) - elif deconstify(c.dtype) != self._target_type: + elif not self._compatible(c.dtype): raise TypificationError( f"Type mismatch at constant {c}: Constant type did not match the context's target type\n" f" Constant type: {c.dtype}\n" @@ -136,7 +136,14 @@ class TypeContext: ) case PsSymbolExpr(symb): - symb.apply_dtype(self._target_type) + if symb.dtype is None: + symb.dtype = self._target_type + elif not self._compatible(symb.dtype): + raise TypificationError( + f"Type mismatch at symbol {symb}: Symbol type did not match the context's target type\n" + f" Symbol type: {symb.dtype}\n" + f" Target type: {self._target_type}" + ) case ( PsIntDiv() @@ -151,9 +158,30 @@ class TypeContext: f" Expression: {expr}" f" Type Context: {self._target_type}" ) - - expr.dtype = self._target_type # endif + expr.dtype = self._target_type + + def _compatible(self, dtype: PsType): + assert self._target_type is not None + if self._target_type.const: + return constify(dtype) == self._target_type + else: + return dtype == self._target_type + + def _fix_constness(self, dtype: PsType, expr: PsExpression | None = None): + if self._require_nonconst: + if dtype.const: + if expr is None: + raise TypificationError( + f"Type mismatch: Encountered {dtype} in non-constant context." + ) + else: + raise TypificationError( + f"Type mismatch at expression {expr}: Encountered {dtype} in non-constant context." + ) + return dtype + else: + return constify(dtype) @property def target_type(self) -> PsType | None: @@ -213,13 +241,21 @@ class Typifier: for s in statements: self.visit(s) - case PsAssignment(lhs, rhs): + case PsDeclaration(lhs, rhs): tc = TypeContext() # LHS defines target type; type context carries it to RHS self.visit_expr(lhs, tc) assert tc.target_type is not None self.visit_expr(rhs, tc) + case PsAssignment(lhs, rhs): + tc_lhs = TypeContext(require_nonconst=True) + self.visit_expr(lhs, tc_lhs) + assert tc_lhs.target_type is not None + + tc_rhs = TypeContext(tc_lhs.target_type, require_nonconst=False) + self.visit_expr(rhs, tc_rhs) + case PsConditional(cond, branch_true, branch_false): cond_tc = TypeContext(PsBoolType(const=True)) self.visit_expr(cond, cond_tc) @@ -233,10 +269,10 @@ class Typifier: if ctr.symbol.dtype is None: ctr.symbol.apply_dtype(self._ctx.index_dtype) - tc = TypeContext(ctr.symbol.dtype) - self.visit_expr(start, tc) - self.visit_expr(stop, tc) - self.visit_expr(step, tc) + tc_lhs = TypeContext(ctr.symbol.dtype) + self.visit_expr(start, tc_lhs) + self.visit_expr(stop, tc_lhs) + self.visit_expr(step, tc_lhs) self.visit(body) @@ -247,10 +283,10 @@ class Typifier: """Recursive processing of expression nodes""" match expr: case PsSymbolExpr(_): - if expr.dtype is None: + if expr.symbol.dtype is None: tc.apply_dtype(expr, self._ctx.default_dtype) else: - tc.apply_dtype(expr, expr.dtype) + tc.apply_dtype(expr, expr.symbol.dtype) case PsConstantExpr(_): tc.infer_dtype(expr) @@ -325,8 +361,12 @@ class Typifier: raise TypificationError( f"Aggregate of type {aggr_type} does not have a member {member}." ) + + member_type = member.dtype + if aggr_type.const: + member_type = constify(member_type) - tc.apply_dtype(expr, member.dtype) + tc.apply_dtype(expr, member_type) case PsBinOp(op1, op2): self.visit_expr(op1, tc) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index ef746c614901a2a15a3b63e117d0c7cec61b9676..6a1a028602a8410fdc36deb578aae6444c8d48a8 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -10,7 +10,7 @@ from pystencils.backend.ast.structural import PsDeclaration from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp from pystencils.backend.constants import PsConstant from pystencils.types import constify -from pystencils.types.quick import Fp, create_numeric_type +from pystencils.types.quick import Fp, create_type, create_numeric_type from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -38,14 +38,14 @@ def test_typify_simple(): assert isinstance(fasm, PsDeclaration) def check(expr): - assert expr.dtype == ctx.default_dtype + assert expr.dtype == constify(ctx.default_dtype) match expr: case PsConstantExpr(cs): assert cs.value == 2 assert cs.dtype == constify(ctx.default_dtype) case PsSymbolExpr(symb): assert symb.name in "xyz" - assert symb.dtype == ctx.default_dtype + assert symb.dtype == constify(ctx.default_dtype) case PsBinOp(op1, op2): check(op1) check(op2) @@ -56,6 +56,82 @@ def test_typify_simple(): check(fasm.rhs) +def test_rhs_constness(): + default_type = Fp(32) + ctx = KernelCreationContext(default_dtype=default_type) + + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + f = Field.create_generic( + "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM + ) + f_const = Field.create_generic( + "f_const", + 1, + index_shape=(1,), + dtype=constify(default_type), + field_type=FieldType.CUSTOM, + ) + + x, y, z = sp.symbols("x, y, z") + + # Right-hand sides should always get const types + asm = typify(freeze(Assignment(x, f.absolute_access([0], [0])))) + assert asm.rhs.get_dtype().const + + asm = typify( + freeze( + Assignment( + f.absolute_access([0], [0]), + f.absolute_access([0], [0]) * f_const.absolute_access([0], [0]) * x + y, + ) + ) + ) + assert asm.rhs.get_dtype().const + + +def test_lhs_constness(): + default_type = Fp(32) + ctx = KernelCreationContext(default_dtype=default_type) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + f = Field.create_generic( + "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM + ) + f_const = Field.create_generic( + "f_const", + 1, + index_shape=(1,), + dtype=constify(default_type), + field_type=FieldType.CUSTOM, + ) + + x, y, z = sp.symbols("x, y, z") + q = TypedSymbol("q", Fp(32, const=True)) + + # Assignment RHS may not be const + asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y))) + assert not asm.lhs.get_dtype().const + + # Cannot assign to const left-hand side + with pytest.raises(TypificationError): + _ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y))) + + # with pytest.raises(TypificationError): + # _ = typify(freeze(Assignment(q, x - z))) + + np_struct = np.dtype([("size", np.uint32), ("data", np.float32)]) + struct_type = constify(create_type(np_struct)) + struct_field = Field.create_generic( + "struct_field", 1, dtype=struct_type, field_type=FieldType.CUSTOM + ) + + with pytest.raises(TypificationError): + _ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x))) + + def test_typify_structs(): ctx = KernelCreationContext(default_dtype=Fp(32)) freeze = FreezeExpressions(ctx) @@ -87,14 +163,14 @@ def test_contextual_typing(): expr = typify(expr) def check(expr): - assert expr.dtype == ctx.default_dtype + assert expr.dtype == constify(ctx.default_dtype) match expr: case PsConstantExpr(cs): assert cs.value in (2, 3, -4) assert cs.dtype == constify(ctx.default_dtype) case PsSymbolExpr(symb): assert symb.name in "xyz" - assert symb.dtype == ctx.default_dtype + assert symb.dtype == constify(ctx.default_dtype) case PsBinOp(op1, op2): check(op1) check(op2) @@ -199,6 +275,9 @@ def test_typify_constant_clones(): expr_clone = expr.clone() expr = typify(expr) - + assert expr_clone.operand1.dtype is None assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None + + +# test_lhs_constness()