diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index c25579029554e1c894c3286c5251eb39f2a1f253..b0dace4e0cf42f21fcb335a58d3e7150185defd2 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import Iterable, Sequence, cast
 from types import NoneType
 
@@ -17,6 +17,24 @@ class PsStructuralNode(PsAstNode, ABC):
     This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from.
     """
 
+    def clone(self) -> PsStructuralNode:
+        """Clone this structure node.
+
+        .. note::
+            Subclasses of `PsStructuralNode` should not override this method,
+            but implement `_clone_structural` instead.
+            That implementation shall call `clone` on any of its children.
+        """
+        return self._clone_structural()
+
+    @abstractmethod
+    def _clone_structural(self) -> PsStructuralNode:
+        """Implementation of structural node cloning.
+
+        :meta public:
+        """
+        pass
+
 
 class PsBlock(PsStructuralNode):
     __match_args__ = ("statements",)
@@ -38,8 +56,8 @@ class PsBlock(PsStructuralNode):
     def set_child(self, idx: int, c: PsAstNode):
         self._statements[idx] = failing_cast(PsStructuralNode, c)
 
-    def clone(self) -> PsBlock:
-        return PsBlock([failing_cast(PsStructuralNode, stmt.clone()) for stmt in self._statements])
+    def _clone_structural(self) -> PsBlock:
+        return PsBlock([stmt._clone_structural() for stmt in self._statements])
 
     @property
     def statements(self) -> list[PsStructuralNode]:
@@ -68,7 +86,7 @@ class PsStatement(PsStructuralNode):
     def expression(self, expr: PsExpression):
         self._expression = expr
 
-    def clone(self) -> PsStatement:
+    def _clone_structural(self) -> PsStatement:
         return PsStatement(self._expression.clone())
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -110,7 +128,7 @@ class PsAssignment(PsStructuralNode):
     def rhs(self, expr: PsExpression):
         self._rhs = expr
 
-    def clone(self) -> PsAssignment:
+    def _clone_structural(self) -> PsAssignment:
         return PsAssignment(self._lhs.clone(), self._rhs.clone())
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -150,7 +168,7 @@ class PsDeclaration(PsAssignment):
     def declared_symbol(self) -> PsSymbol:
         return cast(PsSymbolExpr, self._lhs).symbol
 
-    def clone(self) -> PsDeclaration:
+    def _clone_structural(self) -> PsDeclaration:
         return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone())
 
     def set_child(self, idx: int, c: PsAstNode):
@@ -223,7 +241,7 @@ class PsLoop(PsStructuralNode):
     def body(self, block: PsBlock):
         self._body = block
 
-    def clone(self) -> PsLoop:
+    def _clone_structural(self) -> PsLoop:
         return PsLoop(
             self._ctr.clone(),
             self._start.clone(),
@@ -291,7 +309,7 @@ class PsConditional(PsStructuralNode):
     def branch_false(self, block: PsBlock | None):
         self._branch_false = block
 
-    def clone(self) -> PsConditional:
+    def _clone_structural(self) -> PsConditional:
         return PsConditional(
             self._condition.clone(),
             self._branch_true.clone(),
@@ -344,7 +362,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
     def text(self) -> str:
         return self._text
 
-    def clone(self) -> PsPragma:
+    def _clone_structural(self) -> PsPragma:
         return PsPragma(self.text)
 
     def structurally_equal(self, other: PsAstNode) -> bool:
@@ -369,7 +387,7 @@ class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
     def lines(self) -> tuple[str, ...]:
         return self._lines
 
-    def clone(self) -> PsComment:
+    def _clone_structural(self) -> PsComment:
         return PsComment(self._text)
 
     def structurally_equal(self, other: PsAstNode) -> bool: