diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index e658108e08148a78dd1fd8e834d78c3edcce7c9b..6a04f4f95a390a33cfb59a62f67c04b9c8dd54c6 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -34,6 +34,10 @@ class PsExpression(PsAstNode, ABC): The type annotations are used by various transformation passes to make decisions, e.g. in function materialization and intrinsic selection. + + .. attention:: + The ``structurally_equal`` check currently does not take expression data types into + account. This may change in the future. """ def __init__(self, dtype: PsType | None = None) -> None: @@ -94,8 +98,26 @@ class PsExpression(PsAstNode, ABC): else: raise ValueError(f"Cannot make expression out of {obj}") + def clone(self): + """Clone this expression. + + .. note:: + Subclasses of `PsExpression` should not override this method, + but implement `_clone_expr` instead. + That implementation shall call `clone` on any of its subexpressions, + but does not need to fix the `dtype` property. + The `dtype` is correctly applied by `PsExpression.clone` internally. + """ + cloned = self._clone_expr() + cloned._dtype = self.dtype + return cloned + @abstractmethod - def clone(self) -> PsExpression: + def _clone_expr(self) -> PsExpression: + """Implementation of expression cloning. + + :meta public: + """ pass @@ -121,7 +143,7 @@ class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression): def symbol(self, symbol: PsSymbol): self._symbol = symbol - def clone(self) -> PsSymbolExpr: + def _clone_expr(self) -> PsSymbolExpr: return PsSymbolExpr(self._symbol) def structurally_equal(self, other: PsAstNode) -> bool: @@ -149,7 +171,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): def constant(self, c: PsConstant): self._constant = c - def clone(self) -> PsConstantExpr: + def _clone_expr(self) -> PsConstantExpr: return PsConstantExpr(self._constant) def structurally_equal(self, other: PsAstNode) -> bool: @@ -177,7 +199,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): def literal(self, lit: PsLiteral): self._literal = lit - def clone(self) -> PsLiteralExpr: + def _clone_expr(self) -> PsLiteralExpr: return PsLiteralExpr(self._literal) def structurally_equal(self, other: PsAstNode) -> bool: @@ -240,7 +262,7 @@ class PsBufferAcc(PsLvalue, PsExpression): else: self._index[idx - 1] = failing_cast(PsExpression, c) - def clone(self) -> PsBufferAcc: + def _clone_expr(self) -> PsBufferAcc: return PsBufferAcc(self._base_ptr.symbol, [i.clone() for i in self._index]) def __repr__(self) -> str: @@ -277,7 +299,7 @@ class PsSubscript(PsLvalue, PsExpression): def index(self, idx: Sequence[PsExpression]): self._index = list(idx) - def clone(self) -> PsSubscript: + def _clone_expr(self) -> PsSubscript: return PsSubscript(self._arr.clone(), [i.clone() for i in self._index]) def get_children(self) -> tuple[PsAstNode, ...]: @@ -322,7 +344,7 @@ class PsMemAcc(PsLvalue, PsExpression): def offset(self, expr: PsExpression): self._offset = expr - def clone(self) -> PsMemAcc: + def _clone_expr(self) -> PsMemAcc: return PsMemAcc(self._ptr.clone(), self._offset.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -374,7 +396,7 @@ class PsVectorMemAcc(PsMemAcc): def get_vector_type(self) -> PsVectorType: return cast(PsVectorType, self._dtype) - def clone(self) -> PsVectorMemAcc: + def _clone_expr(self) -> PsVectorMemAcc: return PsVectorMemAcc( self._ptr.clone(), self._offset.clone(), @@ -419,7 +441,7 @@ class PsLookup(PsExpression, PsLvalue): def member_name(self, name: str): self._name = name - def clone(self) -> PsLookup: + def _clone_expr(self) -> PsLookup: return PsLookup(self._aggregate.clone(), self._member_name) def get_children(self) -> tuple[PsAstNode, ...]: @@ -469,7 +491,7 @@ class PsCall(PsExpression): self._args = list(exprs) - def clone(self) -> PsCall: + def _clone_expr(self) -> PsCall: return PsCall(self._function, [arg.clone() for arg in self._args]) def get_children(self) -> tuple[PsAstNode, ...]: @@ -513,7 +535,7 @@ class PsTernary(PsExpression): def case_else(self) -> PsExpression: return self._else - def clone(self) -> PsExpression: + def _clone_expr(self) -> PsExpression: return PsTernary(self._cond.clone(), self._then.clone(), self._else.clone()) def get_children(self) -> tuple[PsExpression, ...]: @@ -563,7 +585,7 @@ class PsUnOp(PsExpression): def operand(self, expr: PsExpression): self._operand = expr - def clone(self) -> PsUnOp: + def _clone_expr(self) -> PsUnOp: return type(self)(self._operand.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -617,7 +639,7 @@ class PsCast(PsUnOp): def target_type(self, dtype: PsType): self._target_type = dtype - def clone(self) -> PsUnOp: + def _clone_expr(self) -> PsUnOp: return PsCast(self._target_type, self._operand.clone()) def structurally_equal(self, other: PsAstNode) -> bool: @@ -653,7 +675,7 @@ class PsBinOp(PsExpression): def operand2(self, expr: PsExpression): self._op2 = expr - def clone(self) -> PsBinOp: + def _clone_expr(self) -> PsBinOp: return type(self)(self._op1.clone(), self._op2.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -838,7 +860,7 @@ class PsArrayInitList(PsExpression): def set_child(self, idx: int, c: PsAstNode): self._items.flat[idx] = failing_cast(PsExpression, c) - def clone(self) -> PsExpression: + def _clone_expr(self) -> PsExpression: return PsArrayInitList( np.array([expr.clone() for expr in self.children]).reshape( # type: ignore self._items.shape diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 57244c03b6413d3fd8c7b521618cb9021b0e2037..e2f202f6594b8591903531ba795be251b6b544d6 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -320,7 +320,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``. Args: - text: The pragma's text, without the ``#pragma ``. + text: The pragmas text, without the ``#pragma``. """ __match_args__ = ("text",) diff --git a/src/pystencils/backend/extensions/cpp.py b/src/pystencils/backend/extensions/cpp.py index 1055b79e9ab197d62c4307b70ac5b2a71c13f139..025f4a3fb61d51d7fd9c485b597a671ae2cfc231 100644 --- a/src/pystencils/backend/extensions/cpp.py +++ b/src/pystencils/backend/extensions/cpp.py @@ -25,7 +25,7 @@ class CppMethodCall(PsForeignExpression): return super().structurally_equal(other) and self._method == other._method - def clone(self) -> CppMethodCall: + def _clone_expr(self) -> CppMethodCall: return CppMethodCall( cast(PsExpression, self.children[0]), self._method, diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 975ea5a6041fbd3e7c408fd9885a79b7b23fcefe..debcc3cf61d4b419f6387dfaca8e1eac655406f4 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -23,6 +23,7 @@ from ..ast.structural import ( PsExpression, PsAssignment, PsDeclaration, + PsStatement, PsEmptyLeafMixIn, ) from ..ast.expressions import ( @@ -301,6 +302,12 @@ class Typifier: for s in statements: self.visit(s) + case PsStatement(expr): + tc = TypeContext() + self.visit_expr(expr, tc) + if tc.target_type is None: + tc.apply_dtype(self._ctx.default_dtype) + case PsDeclaration(lhs, rhs) if isinstance(rhs, PsArrayInitList): # Special treatment for array declarations assert isinstance(lhs, PsSymbolExpr) diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py index 5771eaca84413708c68c4f7941e07cbd63403e9e..8e7d27f58265c08461cba6b05373848112a6fee7 100644 --- a/src/pystencils/types/parsing.py +++ b/src/pystencils/types/parsing.py @@ -8,6 +8,7 @@ from .types import ( PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, + PsBoolType, ) UserTypeSpec = str | type | np.dtype | PsType @@ -143,6 +144,9 @@ def parse_type_string(s: str) -> PsType: def parse_type_name(typename: str, const: bool): match typename: + case "bool": + return PsBoolType(const=const) + case "int" | "int64" | "int64_t": return PsSignedIntegerType(64, const=const) case "int32" | "int32_t": diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py index cf7fd3f31b13f0fbbac3b350f769e6993ab44d9d..2408b8d867038a0f2fd5c4d8a5f22bc82312c701 100644 --- a/tests/nbackend/test_ast.py +++ b/tests/nbackend/test_ast.py @@ -1,6 +1,7 @@ import pytest -from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory +from pystencils import create_type +from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, Typifier from pystencils.backend.memory import PsSymbol, BufferBasePtr from pystencils.backend.constants import PsConstant from pystencils.backend.ast.expressions import ( @@ -15,6 +16,7 @@ from pystencils.backend.ast.expressions import ( from pystencils.backend.ast.structural import ( PsStatement, PsAssignment, + PsDeclaration, PsBlock, PsConditional, PsComment, @@ -25,15 +27,25 @@ from pystencils.types.quick import Fp, Ptr def test_cloning(): - x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"] + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x, y, z, m = [PsExpression.make(ctx.get_symbol(name)) for name in "xyzm"] + q = PsExpression.make(ctx.get_symbol("q", create_type("bool"))) + a, b, c = [PsExpression.make(ctx.get_symbol(name, ctx.index_dtype)) for name in "abc"] c1 = PsExpression.make(PsConstant(3.0)) c2 = PsExpression.make(PsConstant(-1.0)) - one = PsExpression.make(PsConstant(1)) + one_f = PsExpression.make(PsConstant(1.0)) + one_i = PsExpression.make(PsConstant(1)) def check(orig, clone): assert not (orig is clone) assert type(orig) is type(clone) assert orig.structurally_equal(clone) + + if isinstance(orig, PsExpression): + # Regression: Expression data types used to not be cloned + assert orig.dtype == clone.dtype for c1, c2 in zip(orig.children, clone.children, strict=True): check(c1, c2) @@ -49,18 +61,21 @@ def test_cloning(): PsAssignment(y, x / c1), PsBlock([PsAssignment(x, c1 * y), PsAssignment(z, c2 + c1 * z)]), PsConditional( - y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) + q, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) + ), + PsDeclaration( + m, + PsArrayInitList([ + [x, y, one_f + x], + [one_f, c2, z] + ]) ), - PsArrayInitList([ - [x, y, one + x], - [one, c2, z] - ]), PsPragma("omp parallel for"), PsLoop( - x, - y, - z, - one, + a, + b, + c, + one_i, PsBlock( [ PsComment("Loop body"), @@ -68,13 +83,14 @@ def test_cloning(): PsAssignment(x, y), PsPragma("#pragma clang loop vectorize(enable)"), PsStatement( - PsMemAcc(PsCast(Ptr(Fp(32)), z), one) - + PsSubscript(z, (one + one + one, y + one)) + PsMemAcc(PsCast(Ptr(Fp(32)), z), one_i) + + PsCast(Fp(32), PsSubscript(m, (one_i + one_i + one_i, b + one_i))) ), ] ), ), ]: + ast = typify(ast) ast_clone = ast.clone() check(ast, ast_clone)