diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index 5c8fca9adf9546683f3f7489eb137a00f1523b9d..31d8ea192269a9a9947457814ff5e58d63f61c14 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -17,7 +17,7 @@ class PsStructuralNode(PsAstNode, ABC):
     This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from.
     """
 
-    def clone(self) -> PsStructuralNode:
+    def clone(self):
         """Clone this structure node.
 
         .. note::
diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py
index 935ac38e32d6cc77276e607f38e4b21d8062a70c..bd782422f1fa80b96ec7cf69473fda2b1f45c3d6 100644
--- a/src/pystencils/backend/transformations/add_pragmas.py
+++ b/src/pystencils/backend/transformations/add_pragmas.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 from dataclasses import dataclass
 
-from typing import Sequence, cast
+from typing import Sequence
 from collections import defaultdict
 
 from ..kernelcreation import KernelCreationContext
@@ -55,13 +55,12 @@ class InsertPragmasAtLoops:
             self._insertions[ins.loop_nesting_depth].append(ins)
 
     def __call__(self, node: PsAstNode) -> PsAstNode:
-        is_loop = isinstance(node, PsLoop)
-        if is_loop:
-            node = PsBlock([cast(PsLoop, node)])
+        if isinstance(node, PsLoop):
+            node = PsBlock([node])
 
         self.visit(node, Nesting(0))
 
-        if is_loop and len(node.children) == 1:
+        if isinstance(node, PsLoop) and len(node.children) == 1:
             node = node.children[0]
 
         return node
diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py
index 9621699d03c5ee5a06f056da961c4857f015ee04..c793c424d2417cbbdcc0cf3782e696c7c9226bb6 100644
--- a/src/pystencils/backend/transformations/ast_vectorizer.py
+++ b/src/pystencils/backend/transformations/ast_vectorizer.py
@@ -269,12 +269,24 @@ class AstVectorizer:
         """
         return self.visit(node, vc)
 
+    @overload
+    def visit(self, node: PsStructuralNode, vc: VectorizationContext) -> PsStructuralNode:
+        pass
+
+    @overload
+    def visit(self, node: PsExpression, vc: VectorizationContext) -> PsExpression:
+        pass
+
+    @overload
+    def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
+        pass
+
     def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
         """Vectorize a subtree."""
 
         match node:
             case PsBlock(stmts):
-                return PsBlock([cast(PsStructuralNode, self.visit(n, vc)) for n in stmts])
+                return PsBlock([self.visit(n, vc) for n in stmts])
 
             case PsExpression():
                 return self.visit_expr(node, vc)
diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py
index ca24e49b774d58007c5652a72aaa0ec4d8f4c9f6..69dd1dd11d726e597c15ece772846ba8cd84acba 100644
--- a/src/pystencils/backend/transformations/eliminate_branches.py
+++ b/src/pystencils/backend/transformations/eliminate_branches.py
@@ -68,7 +68,7 @@ class EliminateBranches:
     def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode:
         match node:
             case PsLoop(_, _, _, _, body):
-                ec.enclosing_loops.append(cast(PsLoop, node))
+                ec.enclosing_loops.append(node)
                 self.visit(body, ec)
                 ec.enclosing_loops.pop()
 
diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index b66efe4f25b3ef37019595abf296e42c3343f272..3a07cb56fcb8f1c60107b5b1883c679191429e7e 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -36,6 +36,7 @@ from ..ast.expressions import (
 )
 from ..ast.vector import PsVecBroadcast
 from ..ast.util import AstEqWrapper
+from ..exceptions import PsInternalCompilerError
 
 from ..constants import PsConstant
 from ..memory import PsSymbol
@@ -138,13 +139,18 @@ class EliminateConstants:
         node = self.visit(node, ecc)
 
         if ecc.extractions:
+            if not isinstance(node, PsStructuralNode):
+                raise PsInternalCompilerError(
+                    f"Cannot extract constant expressions from outermost node {node}"
+                )
+
             prepend_decls = [
                 PsDeclaration(PsExpression.make(symb), expr)
                 for symb, expr in ecc.extractions
             ]
 
             if not isinstance(node, PsBlock):
-                node = PsBlock(prepend_decls + [cast(PsStructuralNode, node)])
+                node = PsBlock(prepend_decls + [node])
             else:
                 node.children = prepend_decls + list(node.children)
 
diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
index 9637485dde4cc296cedbc0e24329949fda027292..f7fe81ad736981bee6f38427fbd4face73f0c455 100644
--- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
+++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
@@ -91,7 +91,7 @@ class HoistLoopInvariantDeclarations:
         """Search the outermost loop and start the hoisting cascade there."""
         match node:
             case PsLoop():
-                temp_block = PsBlock([cast(PsLoop, node)])
+                temp_block = PsBlock([node])
                 temp_block = cast(PsBlock, self.visit(temp_block))
                 if temp_block.statements == [node]:
                     return node
diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py
index 6b518a30de31acac67c061ead1683c1d6ab06816..e1e4fea502c08de86e13de5e3c251f1b7a7d0ee6 100644
--- a/src/pystencils/backend/transformations/loop_vectorizer.py
+++ b/src/pystencils/backend/transformations/loop_vectorizer.py
@@ -213,7 +213,7 @@ class LoopVectorizer:
 
                 trailing_ctr = self._ctx.duplicate_symbol(scalar_ctr)
                 trailing_loop_body = substitute_symbols(
-                    loop.body._clone_structural(), {scalar_ctr: PsExpression.make(trailing_ctr)}
+                    loop.body.clone(), {scalar_ctr: PsExpression.make(trailing_ctr)}
                 )
                 trailing_loop = PsLoop(
                     PsExpression.make(trailing_ctr),
diff --git a/src/pystencils/backend/transformations/rewrite.py b/src/pystencils/backend/transformations/rewrite.py
index 59241c295f42eeaf60f4cd03a5138214fdbd6c50..8dff9e45ec283fc6c3712c2e77ff56a9b2aaeae5 100644
--- a/src/pystencils/backend/transformations/rewrite.py
+++ b/src/pystencils/backend/transformations/rewrite.py
@@ -2,7 +2,7 @@ from typing import overload
 
 from ..memory import PsSymbol
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock
+from ..ast.structural import PsStructuralNode, PsBlock
 from ..ast.expressions import PsExpression, PsSymbolExpr
 
 
@@ -18,6 +18,13 @@ def substitute_symbols(
     pass
 
 
+@overload
+def substitute_symbols(
+    node: PsStructuralNode, subs: dict[PsSymbol, PsExpression]
+) -> PsStructuralNode:
+    pass
+
+
 @overload
 def substitute_symbols(
     node: PsAstNode, subs: dict[PsSymbol, PsExpression]