diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index c25579029554e1c894c3286c5251eb39f2a1f253..b0dace4e0cf42f21fcb335a58d3e7150185defd2 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from typing import Iterable, Sequence, cast from types import NoneType @@ -17,6 +17,24 @@ class PsStructuralNode(PsAstNode, ABC): This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from. """ + def clone(self) -> PsStructuralNode: + """Clone this structure node. + + .. note:: + Subclasses of `PsStructuralNode` should not override this method, + but implement `_clone_structural` instead. + That implementation shall call `clone` on any of its children. + """ + return self._clone_structural() + + @abstractmethod + def _clone_structural(self) -> PsStructuralNode: + """Implementation of structural node cloning. + + :meta public: + """ + pass + class PsBlock(PsStructuralNode): __match_args__ = ("statements",) @@ -38,8 +56,8 @@ class PsBlock(PsStructuralNode): def set_child(self, idx: int, c: PsAstNode): self._statements[idx] = failing_cast(PsStructuralNode, c) - def clone(self) -> PsBlock: - return PsBlock([failing_cast(PsStructuralNode, stmt.clone()) for stmt in self._statements]) + def _clone_structural(self) -> PsBlock: + return PsBlock([stmt._clone_structural() for stmt in self._statements]) @property def statements(self) -> list[PsStructuralNode]: @@ -68,7 +86,7 @@ class PsStatement(PsStructuralNode): def expression(self, expr: PsExpression): self._expression = expr - def clone(self) -> PsStatement: + def _clone_structural(self) -> PsStatement: return PsStatement(self._expression.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -110,7 +128,7 @@ class PsAssignment(PsStructuralNode): def rhs(self, expr: PsExpression): self._rhs = expr - def clone(self) -> PsAssignment: + def _clone_structural(self) -> PsAssignment: return PsAssignment(self._lhs.clone(), self._rhs.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -150,7 +168,7 @@ class PsDeclaration(PsAssignment): def declared_symbol(self) -> PsSymbol: return cast(PsSymbolExpr, self._lhs).symbol - def clone(self) -> PsDeclaration: + def _clone_structural(self) -> PsDeclaration: return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone()) def set_child(self, idx: int, c: PsAstNode): @@ -223,7 +241,7 @@ class PsLoop(PsStructuralNode): def body(self, block: PsBlock): self._body = block - def clone(self) -> PsLoop: + def _clone_structural(self) -> PsLoop: return PsLoop( self._ctr.clone(), self._start.clone(), @@ -291,7 +309,7 @@ class PsConditional(PsStructuralNode): def branch_false(self, block: PsBlock | None): self._branch_false = block - def clone(self) -> PsConditional: + def _clone_structural(self) -> PsConditional: return PsConditional( self._condition.clone(), self._branch_true.clone(), @@ -344,7 +362,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode): def text(self) -> str: return self._text - def clone(self) -> PsPragma: + def _clone_structural(self) -> PsPragma: return PsPragma(self.text) def structurally_equal(self, other: PsAstNode) -> bool: @@ -369,7 +387,7 @@ class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode): def lines(self) -> tuple[str, ...]: return self._lines - def clone(self) -> PsComment: + def _clone_structural(self) -> PsComment: return PsComment(self._text) def structurally_equal(self, other: PsAstNode) -> bool: