From 65f310985cdba73caee4f8b0d4c0cc50850174c0 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Fri, 21 Feb 2025 19:57:24 +0100
Subject: [PATCH 1/7] Introduce structural ast node trait and employ in
 structural.py

---
 src/pystencils/backend/ast/structural.py | 31 +++++++++++++++---------
 1 file changed, 20 insertions(+), 11 deletions(-)

diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index 2c79f4f46..98ec72039 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -1,4 +1,6 @@
 from __future__ import annotations
+
+from abc import ABC
 from typing import Iterable, Sequence, cast
 from types import NoneType
 
@@ -9,10 +11,17 @@ from ..memory import PsSymbol
 from .util import failing_cast
 
 
-class PsBlock(PsAstNode):
+class PsStructuralAstNode(PsAstNode, ABC):
+    """Base class for structural nodes in the pystencils AST.
+
+    This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from.
+    """
+
+
+class PsBlock(PsStructuralAstNode):
     __match_args__ = ("statements",)
 
-    def __init__(self, cs: Iterable[PsAstNode]):
+    def __init__(self, cs: Iterable[PsStructuralAstNode]):
         self._statements = list(cs)
 
     @property
@@ -27,17 +36,17 @@ class PsBlock(PsAstNode):
         return tuple(self._statements)
 
     def set_child(self, idx: int, c: PsAstNode):
-        self._statements[idx] = c
+        self._statements[idx] = failing_cast(PsStructuralAstNode, c)
 
     def clone(self) -> PsBlock:
         return PsBlock([stmt.clone() for stmt in self._statements])
 
     @property
-    def statements(self) -> list[PsAstNode]:
+    def statements(self) -> list[PsStructuralAstNode]:
         return self._statements
 
     @statements.setter
-    def statements(self, stm: Sequence[PsAstNode]):
+    def statements(self, stm: Sequence[PsStructuralAstNode]):
         self._statements = list(stm)
 
     def __repr__(self) -> str:
@@ -45,7 +54,7 @@ class PsBlock(PsAstNode):
         return f"PsBlock( {contents} )"
 
 
-class PsStatement(PsAstNode):
+class PsStatement(PsStructuralAstNode):
     __match_args__ = ("expression",)
 
     def __init__(self, expr: PsExpression):
@@ -71,7 +80,7 @@ class PsStatement(PsAstNode):
         self._expression = failing_cast(PsExpression, c)
 
 
-class PsAssignment(PsAstNode):
+class PsAssignment(PsStructuralAstNode):
     __match_args__ = (
         "lhs",
         "rhs",
@@ -157,7 +166,7 @@ class PsDeclaration(PsAssignment):
         return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})"
 
 
-class PsLoop(PsAstNode):
+class PsLoop(PsStructuralAstNode):
     __match_args__ = ("counter", "start", "stop", "step", "body")
 
     def __init__(
@@ -243,7 +252,7 @@ class PsLoop(PsAstNode):
                 assert False, "unreachable code"
 
 
-class PsConditional(PsAstNode):
+class PsConditional(PsStructuralAstNode):
     """Conditional branch"""
 
     __match_args__ = ("condition", "branch_true", "branch_false")
@@ -317,7 +326,7 @@ class PsEmptyLeafMixIn:
     pass
 
 
-class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
+class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralAstNode):
     """A C/C++ preprocessor pragma.
 
     Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``.
@@ -345,7 +354,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
         return self._text == other._text
 
 
-class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
+class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralAstNode):
     __match_args__ = ("lines",)
 
     def __init__(self, text: str) -> None:
-- 
GitLab


From a4b86fdd0611be29ed7c907ce41f46990bbef43e Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Thu, 27 Feb 2025 17:32:24 +0100
Subject: [PATCH 2/7] Try fixing typecheck for newly introduced structural ast
 nodes

---
 src/pystencils/backend/ast/structural.py               |  4 ++--
 src/pystencils/backend/kernelcreation/freeze.py        |  3 ++-
 src/pystencils/backend/transformations/add_pragmas.py  |  8 ++++----
 .../backend/transformations/ast_vectorizer.py          |  3 ++-
 .../backend/transformations/eliminate_branches.py      | 10 ++++++----
 .../backend/transformations/eliminate_constants.py     |  4 ++--
 .../transformations/hoist_loop_invariant_decls.py      | 10 +++++-----
 7 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index 98ec72039..5966304e2 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -30,7 +30,7 @@ class PsBlock(PsStructuralAstNode):
 
     @children.setter
     def children(self, cs: Sequence[PsAstNode]):
-        self._statements = list(cs)
+        self._statements = list([failing_cast(PsStructuralAstNode, c) for c in cs])
 
     def get_children(self) -> tuple[PsAstNode, ...]:
         return tuple(self._statements)
@@ -39,7 +39,7 @@ class PsBlock(PsStructuralAstNode):
         self._statements[idx] = failing_cast(PsStructuralAstNode, c)
 
     def clone(self) -> PsBlock:
-        return PsBlock([stmt.clone() for stmt in self._statements])
+        return PsBlock([failing_cast(PsStructuralAstNode, stmt.clone()) for stmt in self._statements])
 
     @property
     def statements(self) -> list[PsStructuralAstNode]:
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 4fd09f879..2213320c8 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -26,6 +26,7 @@ from ..ast.structural import (
     PsDeclaration,
     PsExpression,
     PsSymbolExpr,
+    PsStructuralAstNode,
 )
 from ..ast.expressions import (
     PsBufferAcc,
@@ -107,7 +108,7 @@ class FreezeExpressions:
 
     def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode:
         if isinstance(obj, AssignmentCollection):
-            return PsBlock([self.visit(asm) for asm in obj.all_assignments])
+            return PsBlock([cast(PsStructuralAstNode, self.visit(asm)) for asm in obj.all_assignments])
         elif isinstance(obj, AssignmentBase):
             return cast(PsAssignment, self.visit(obj))
         elif isinstance(obj, _ExprLike):
diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py
index 0e6d314ac..3b6a2c18d 100644
--- a/src/pystencils/backend/transformations/add_pragmas.py
+++ b/src/pystencils/backend/transformations/add_pragmas.py
@@ -1,12 +1,12 @@
 from __future__ import annotations
 from dataclasses import dataclass
 
-from typing import Sequence
+from typing import Sequence, cast
 from collections import defaultdict
 
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsLoop, PsPragma
+from ..ast.structural import PsBlock, PsLoop, PsPragma, PsStructuralAstNode
 from ..ast.expressions import PsExpression
 
 
@@ -57,7 +57,7 @@ class InsertPragmasAtLoops:
     def __call__(self, node: PsAstNode) -> PsAstNode:
         is_loop = isinstance(node, PsLoop)
         if is_loop:
-            node = PsBlock([node])
+            node = PsBlock([cast(PsLoop, node)])
 
         self.visit(node, Nesting(0))
 
@@ -72,7 +72,7 @@ class InsertPragmasAtLoops:
                 return
 
             case PsBlock(children):
-                new_children: list[PsAstNode] = []
+                new_children: list[PsStructuralAstNode] = []
                 for c in children:
                     if isinstance(c, PsLoop):
                         nest.has_inner_loops = True
diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py
index ab4401f9c..93484932d 100644
--- a/src/pystencils/backend/transformations/ast_vectorizer.py
+++ b/src/pystencils/backend/transformations/ast_vectorizer.py
@@ -18,6 +18,7 @@ from ..ast.structural import (
     PsAssignment,
     PsLoop,
     PsEmptyLeafMixIn,
+    PsStructuralAstNode,
 )
 from ..ast.expressions import (
     PsExpression,
@@ -273,7 +274,7 @@ class AstVectorizer:
 
         match node:
             case PsBlock(stmts):
-                return PsBlock([self.visit(n, vc) for n in stmts])
+                return PsBlock([cast(PsStructuralAstNode, self.visit(n, vc)) for n in stmts])
 
             case PsExpression():
                 return self.visit_expr(node, vc)
diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py
index f098d82df..02e406bc8 100644
--- a/src/pystencils/backend/transformations/eliminate_branches.py
+++ b/src/pystencils/backend/transformations/eliminate_branches.py
@@ -1,7 +1,9 @@
+from typing import cast
+
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
 from ..ast.analysis import collect_undefined_symbols
-from ..ast.structural import PsLoop, PsBlock, PsConditional
+from ..ast.structural import PsLoop, PsBlock, PsConditional, PsStructuralAstNode
 from ..ast.expressions import (
     PsAnd,
     PsCast,
@@ -66,14 +68,14 @@ class EliminateBranches:
     def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode:
         match node:
             case PsLoop(_, _, _, _, body):
-                ec.enclosing_loops.append(node)
+                ec.enclosing_loops.append(cast(PsLoop, node))
                 self.visit(body, ec)
                 ec.enclosing_loops.pop()
 
             case PsBlock(statements):
-                statements_new: list[PsAstNode] = []
+                statements_new: list[PsStructuralAstNode] = []
                 for stmt in statements:
-                    statements_new.append(self.visit(stmt, ec))
+                    statements_new.append(cast(PsStructuralAstNode, self.visit(stmt, ec)))
                 node.statements = statements_new
 
             case PsConditional():
diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index ab1cabc55..ea59e4f23 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -6,7 +6,7 @@ import numpy as np
 from ..kernelcreation import KernelCreationContext, Typifier
 
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsDeclaration
+from ..ast.structural import PsBlock, PsDeclaration, PsStructuralAstNode
 from ..ast.expressions import (
     PsExpression,
     PsConstantExpr,
@@ -144,7 +144,7 @@ class EliminateConstants:
             ]
 
             if not isinstance(node, PsBlock):
-                node = PsBlock(prepend_decls + [node])
+                node = PsBlock(prepend_decls + [cast(PsStructuralAstNode, node)])
             else:
                 node.children = prepend_decls + list(node.children)
 
diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
index f0e4cc9f1..7369b3ef0 100644
--- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
+++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
@@ -2,7 +2,7 @@ from typing import cast
 
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment
+from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment, PsStructuralAstNode
 from ..ast.expressions import (
     PsExpression,
     PsSymbolExpr,
@@ -91,7 +91,7 @@ class HoistLoopInvariantDeclarations:
         """Search the outermost loop and start the hoisting cascade there."""
         match node:
             case PsLoop():
-                temp_block = PsBlock([node])
+                temp_block = PsBlock([cast(PsLoop, node)])
                 temp_block = cast(PsBlock, self.visit(temp_block))
                 if temp_block.statements == [node]:
                     return node
@@ -99,7 +99,7 @@ class HoistLoopInvariantDeclarations:
                     return temp_block
 
             case PsBlock(statements):
-                statements_new: list[PsAstNode] = []
+                statements_new: list[PsStructuralAstNode] = []
                 for stmt in statements:
                     if isinstance(stmt, PsLoop):
                         loop = stmt
@@ -153,7 +153,7 @@ class HoistLoopInvariantDeclarations:
                 return
 
             case PsBlock(statements):
-                statements_new: list[PsAstNode] = []
+                statements_new: list[PsStructuralAstNode] = []
                 for stmt in statements:
                     if isinstance(stmt, PsLoop):
                         loop = stmt
@@ -178,7 +178,7 @@ class HoistLoopInvariantDeclarations:
         This method processes only statements of the given block, and any blocks directly nested inside it.
         It does not descend into control structures like conditionals and nested loops.
         """
-        statements_new: list[PsAstNode] = []
+        statements_new: list[PsStructuralAstNode] = []
 
         for node in block.statements:
             if isinstance(node, PsDeclaration):
-- 
GitLab


From 7ea2382e51410078a7813450dbe2292895bc1121 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Mon, 3 Mar 2025 16:15:56 +0100
Subject: [PATCH 3/7] Fix typecheck

---
 src/pystencils/backend/transformations/add_pragmas.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py
index 3b6a2c18d..b7d66fbbd 100644
--- a/src/pystencils/backend/transformations/add_pragmas.py
+++ b/src/pystencils/backend/transformations/add_pragmas.py
@@ -91,8 +91,8 @@ class InsertPragmasAtLoops:
                 node.children = new_children
 
             case other:
-                for c in other.children:
-                    self.visit(c, nest)
+                for child in other.children:
+                    self.visit(child, nest)
 
 
 class AddOpenMP:
-- 
GitLab


From 216ef8bcc504647428b8f6c128ccf1aee6694ce8 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Mon, 3 Mar 2025 16:18:02 +0100
Subject: [PATCH 4/7] Rename newly introduced node to PsStructuralNode

---
 src/pystencils/backend/ast/structural.py      | 28 +++++++++----------
 .../backend/kernelcreation/freeze.py          |  4 +--
 .../backend/transformations/add_pragmas.py    |  4 +--
 .../backend/transformations/ast_vectorizer.py |  4 +--
 .../transformations/eliminate_branches.py     |  6 ++--
 .../transformations/eliminate_constants.py    |  4 +--
 .../hoist_loop_invariant_decls.py             |  8 +++---
 7 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index 5966304e2..c25579029 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -11,17 +11,17 @@ from ..memory import PsSymbol
 from .util import failing_cast
 
 
-class PsStructuralAstNode(PsAstNode, ABC):
+class PsStructuralNode(PsAstNode, ABC):
     """Base class for structural nodes in the pystencils AST.
 
     This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from.
     """
 
 
-class PsBlock(PsStructuralAstNode):
+class PsBlock(PsStructuralNode):
     __match_args__ = ("statements",)
 
-    def __init__(self, cs: Iterable[PsStructuralAstNode]):
+    def __init__(self, cs: Iterable[PsStructuralNode]):
         self._statements = list(cs)
 
     @property
@@ -30,23 +30,23 @@ class PsBlock(PsStructuralAstNode):
 
     @children.setter
     def children(self, cs: Sequence[PsAstNode]):
-        self._statements = list([failing_cast(PsStructuralAstNode, c) for c in cs])
+        self._statements = list([failing_cast(PsStructuralNode, c) for c in cs])
 
     def get_children(self) -> tuple[PsAstNode, ...]:
         return tuple(self._statements)
 
     def set_child(self, idx: int, c: PsAstNode):
-        self._statements[idx] = failing_cast(PsStructuralAstNode, c)
+        self._statements[idx] = failing_cast(PsStructuralNode, c)
 
     def clone(self) -> PsBlock:
-        return PsBlock([failing_cast(PsStructuralAstNode, stmt.clone()) for stmt in self._statements])
+        return PsBlock([failing_cast(PsStructuralNode, stmt.clone()) for stmt in self._statements])
 
     @property
-    def statements(self) -> list[PsStructuralAstNode]:
+    def statements(self) -> list[PsStructuralNode]:
         return self._statements
 
     @statements.setter
-    def statements(self, stm: Sequence[PsStructuralAstNode]):
+    def statements(self, stm: Sequence[PsStructuralNode]):
         self._statements = list(stm)
 
     def __repr__(self) -> str:
@@ -54,7 +54,7 @@ class PsBlock(PsStructuralAstNode):
         return f"PsBlock( {contents} )"
 
 
-class PsStatement(PsStructuralAstNode):
+class PsStatement(PsStructuralNode):
     __match_args__ = ("expression",)
 
     def __init__(self, expr: PsExpression):
@@ -80,7 +80,7 @@ class PsStatement(PsStructuralAstNode):
         self._expression = failing_cast(PsExpression, c)
 
 
-class PsAssignment(PsStructuralAstNode):
+class PsAssignment(PsStructuralNode):
     __match_args__ = (
         "lhs",
         "rhs",
@@ -166,7 +166,7 @@ class PsDeclaration(PsAssignment):
         return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})"
 
 
-class PsLoop(PsStructuralAstNode):
+class PsLoop(PsStructuralNode):
     __match_args__ = ("counter", "start", "stop", "step", "body")
 
     def __init__(
@@ -252,7 +252,7 @@ class PsLoop(PsStructuralAstNode):
                 assert False, "unreachable code"
 
 
-class PsConditional(PsStructuralAstNode):
+class PsConditional(PsStructuralNode):
     """Conditional branch"""
 
     __match_args__ = ("condition", "branch_true", "branch_false")
@@ -326,7 +326,7 @@ class PsEmptyLeafMixIn:
     pass
 
 
-class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralAstNode):
+class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
     """A C/C++ preprocessor pragma.
 
     Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``.
@@ -354,7 +354,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralAstNode):
         return self._text == other._text
 
 
-class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralAstNode):
+class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
     __match_args__ = ("lines",)
 
     def __init__(self, text: str) -> None:
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 2213320c8..b3ff5aefb 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -26,7 +26,7 @@ from ..ast.structural import (
     PsDeclaration,
     PsExpression,
     PsSymbolExpr,
-    PsStructuralAstNode,
+    PsStructuralNode,
 )
 from ..ast.expressions import (
     PsBufferAcc,
@@ -108,7 +108,7 @@ class FreezeExpressions:
 
     def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode:
         if isinstance(obj, AssignmentCollection):
-            return PsBlock([cast(PsStructuralAstNode, self.visit(asm)) for asm in obj.all_assignments])
+            return PsBlock([cast(PsStructuralNode, self.visit(asm)) for asm in obj.all_assignments])
         elif isinstance(obj, AssignmentBase):
             return cast(PsAssignment, self.visit(obj))
         elif isinstance(obj, _ExprLike):
diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py
index b7d66fbbd..935ac38e3 100644
--- a/src/pystencils/backend/transformations/add_pragmas.py
+++ b/src/pystencils/backend/transformations/add_pragmas.py
@@ -6,7 +6,7 @@ from collections import defaultdict
 
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsLoop, PsPragma, PsStructuralAstNode
+from ..ast.structural import PsBlock, PsLoop, PsPragma, PsStructuralNode
 from ..ast.expressions import PsExpression
 
 
@@ -72,7 +72,7 @@ class InsertPragmasAtLoops:
                 return
 
             case PsBlock(children):
-                new_children: list[PsStructuralAstNode] = []
+                new_children: list[PsStructuralNode] = []
                 for c in children:
                     if isinstance(c, PsLoop):
                         nest.has_inner_loops = True
diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py
index 93484932d..9621699d0 100644
--- a/src/pystencils/backend/transformations/ast_vectorizer.py
+++ b/src/pystencils/backend/transformations/ast_vectorizer.py
@@ -18,7 +18,7 @@ from ..ast.structural import (
     PsAssignment,
     PsLoop,
     PsEmptyLeafMixIn,
-    PsStructuralAstNode,
+    PsStructuralNode,
 )
 from ..ast.expressions import (
     PsExpression,
@@ -274,7 +274,7 @@ class AstVectorizer:
 
         match node:
             case PsBlock(stmts):
-                return PsBlock([cast(PsStructuralAstNode, self.visit(n, vc)) for n in stmts])
+                return PsBlock([cast(PsStructuralNode, self.visit(n, vc)) for n in stmts])
 
             case PsExpression():
                 return self.visit_expr(node, vc)
diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py
index 02e406bc8..ca24e49b7 100644
--- a/src/pystencils/backend/transformations/eliminate_branches.py
+++ b/src/pystencils/backend/transformations/eliminate_branches.py
@@ -3,7 +3,7 @@ from typing import cast
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
 from ..ast.analysis import collect_undefined_symbols
-from ..ast.structural import PsLoop, PsBlock, PsConditional, PsStructuralAstNode
+from ..ast.structural import PsLoop, PsBlock, PsConditional, PsStructuralNode
 from ..ast.expressions import (
     PsAnd,
     PsCast,
@@ -73,9 +73,9 @@ class EliminateBranches:
                 ec.enclosing_loops.pop()
 
             case PsBlock(statements):
-                statements_new: list[PsStructuralAstNode] = []
+                statements_new: list[PsStructuralNode] = []
                 for stmt in statements:
-                    statements_new.append(cast(PsStructuralAstNode, self.visit(stmt, ec)))
+                    statements_new.append(cast(PsStructuralNode, self.visit(stmt, ec)))
                 node.statements = statements_new
 
             case PsConditional():
diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index ea59e4f23..b66efe4f2 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -6,7 +6,7 @@ import numpy as np
 from ..kernelcreation import KernelCreationContext, Typifier
 
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsDeclaration, PsStructuralAstNode
+from ..ast.structural import PsBlock, PsDeclaration, PsStructuralNode
 from ..ast.expressions import (
     PsExpression,
     PsConstantExpr,
@@ -144,7 +144,7 @@ class EliminateConstants:
             ]
 
             if not isinstance(node, PsBlock):
-                node = PsBlock(prepend_decls + [cast(PsStructuralAstNode, node)])
+                node = PsBlock(prepend_decls + [cast(PsStructuralNode, node)])
             else:
                 node.children = prepend_decls + list(node.children)
 
diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
index 7369b3ef0..9637485dd 100644
--- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
+++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
@@ -2,7 +2,7 @@ from typing import cast
 
 from ..kernelcreation import KernelCreationContext
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment, PsStructuralAstNode
+from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment, PsStructuralNode
 from ..ast.expressions import (
     PsExpression,
     PsSymbolExpr,
@@ -99,7 +99,7 @@ class HoistLoopInvariantDeclarations:
                     return temp_block
 
             case PsBlock(statements):
-                statements_new: list[PsStructuralAstNode] = []
+                statements_new: list[PsStructuralNode] = []
                 for stmt in statements:
                     if isinstance(stmt, PsLoop):
                         loop = stmt
@@ -153,7 +153,7 @@ class HoistLoopInvariantDeclarations:
                 return
 
             case PsBlock(statements):
-                statements_new: list[PsStructuralAstNode] = []
+                statements_new: list[PsStructuralNode] = []
                 for stmt in statements:
                     if isinstance(stmt, PsLoop):
                         loop = stmt
@@ -178,7 +178,7 @@ class HoistLoopInvariantDeclarations:
         This method processes only statements of the given block, and any blocks directly nested inside it.
         It does not descend into control structures like conditionals and nested loops.
         """
-        statements_new: list[PsStructuralAstNode] = []
+        statements_new: list[PsStructuralNode] = []
 
         for node in block.statements:
             if isinstance(node, PsDeclaration):
-- 
GitLab


From 8ceb678194386094f9ef8ddc8b8af74941135295 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Mon, 3 Mar 2025 16:34:40 +0100
Subject: [PATCH 5/7] Introduce _clone_structural function for PsStructural
 node (similar to PsExpression)

---
 src/pystencils/backend/ast/structural.py | 38 +++++++++++++++++-------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index c25579029..b0dace4e0 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import Iterable, Sequence, cast
 from types import NoneType
 
@@ -17,6 +17,24 @@ class PsStructuralNode(PsAstNode, ABC):
     This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from.
     """
 
+    def clone(self) -> PsStructuralNode:
+        """Clone this structure node.
+
+        .. note::
+            Subclasses of `PsStructuralNode` should not override this method,
+            but implement `_clone_structural` instead.
+            That implementation shall call `clone` on any of its children.
+        """
+        return self._clone_structural()
+
+    @abstractmethod
+    def _clone_structural(self) -> PsStructuralNode:
+        """Implementation of structural node cloning.
+
+        :meta public:
+        """
+        pass
+
 
 class PsBlock(PsStructuralNode):
     __match_args__ = ("statements",)
@@ -38,8 +56,8 @@ class PsBlock(PsStructuralNode):
     def set_child(self, idx: int, c: PsAstNode):
         self._statements[idx] = failing_cast(PsStructuralNode, c)
 
-    def clone(self) -> PsBlock:
-        return PsBlock([failing_cast(PsStructuralNode, stmt.clone()) for stmt in self._statements])
+    def _clone_structural(self) -> PsBlock:
+        return PsBlock([stmt._clone_structural() for stmt in self._statements])
 
     @property
     def statements(self) -> list[PsStructuralNode]:
@@ -68,7 +86,7 @@ class PsStatement(PsStructuralNode):
     def expression(self, expr: PsExpression):
         self._expression = expr
 
-    def clone(self) -> PsStatement:
+    def _clone_structural(self) -> PsStatement:
         return PsStatement(self._expression.clone())
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -110,7 +128,7 @@ class PsAssignment(PsStructuralNode):
     def rhs(self, expr: PsExpression):
         self._rhs = expr
 
-    def clone(self) -> PsAssignment:
+    def _clone_structural(self) -> PsAssignment:
         return PsAssignment(self._lhs.clone(), self._rhs.clone())
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -150,7 +168,7 @@ class PsDeclaration(PsAssignment):
     def declared_symbol(self) -> PsSymbol:
         return cast(PsSymbolExpr, self._lhs).symbol
 
-    def clone(self) -> PsDeclaration:
+    def _clone_structural(self) -> PsDeclaration:
         return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone())
 
     def set_child(self, idx: int, c: PsAstNode):
@@ -223,7 +241,7 @@ class PsLoop(PsStructuralNode):
     def body(self, block: PsBlock):
         self._body = block
 
-    def clone(self) -> PsLoop:
+    def _clone_structural(self) -> PsLoop:
         return PsLoop(
             self._ctr.clone(),
             self._start.clone(),
@@ -291,7 +309,7 @@ class PsConditional(PsStructuralNode):
     def branch_false(self, block: PsBlock | None):
         self._branch_false = block
 
-    def clone(self) -> PsConditional:
+    def _clone_structural(self) -> PsConditional:
         return PsConditional(
             self._condition.clone(),
             self._branch_true.clone(),
@@ -344,7 +362,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
     def text(self) -> str:
         return self._text
 
-    def clone(self) -> PsPragma:
+    def _clone_structural(self) -> PsPragma:
         return PsPragma(self.text)
 
     def structurally_equal(self, other: PsAstNode) -> bool:
@@ -369,7 +387,7 @@ class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode):
     def lines(self) -> tuple[str, ...]:
         return self._lines
 
-    def clone(self) -> PsComment:
+    def _clone_structural(self) -> PsComment:
         return PsComment(self._text)
 
     def structurally_equal(self, other: PsAstNode) -> bool:
-- 
GitLab


From 84b4c13bf6758f1cc664eebb647f57740d83e831 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Mon, 3 Mar 2025 16:39:59 +0100
Subject: [PATCH 6/7] Fix typecheck

---
 src/pystencils/backend/ast/structural.py                  | 6 +++---
 src/pystencils/backend/transformations/loop_vectorizer.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index b0dace4e0..5c8fca9ad 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -247,7 +247,7 @@ class PsLoop(PsStructuralNode):
             self._start.clone(),
             self._stop.clone(),
             self._step.clone(),
-            self._body.clone(),
+            self._body._clone_structural(),
         )
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -312,8 +312,8 @@ class PsConditional(PsStructuralNode):
     def _clone_structural(self) -> PsConditional:
         return PsConditional(
             self._condition.clone(),
-            self._branch_true.clone(),
-            self._branch_false.clone() if self._branch_false is not None else None,
+            self._branch_true._clone_structural(),
+            self._branch_false._clone_structural() if self._branch_false is not None else None,
         )
 
     def get_children(self) -> tuple[PsAstNode, ...]:
diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py
index e1e4fea50..6b518a30d 100644
--- a/src/pystencils/backend/transformations/loop_vectorizer.py
+++ b/src/pystencils/backend/transformations/loop_vectorizer.py
@@ -213,7 +213,7 @@ class LoopVectorizer:
 
                 trailing_ctr = self._ctx.duplicate_symbol(scalar_ctr)
                 trailing_loop_body = substitute_symbols(
-                    loop.body.clone(), {scalar_ctr: PsExpression.make(trailing_ctr)}
+                    loop.body._clone_structural(), {scalar_ctr: PsExpression.make(trailing_ctr)}
                 )
                 trailing_loop = PsLoop(
                     PsExpression.make(trailing_ctr),
-- 
GitLab


From 7991adb679382e0a6e7a4a9ea2f8bb3b0084729f Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 5 Mar 2025 12:04:59 +0100
Subject: [PATCH 7/7] remove various casts. clean up type annotations. fix one
 small bug.

---
 src/pystencils/backend/ast/structural.py           |  2 +-
 .../backend/transformations/add_pragmas.py         |  9 ++++-----
 .../backend/transformations/ast_vectorizer.py      | 14 +++++++++++++-
 .../backend/transformations/eliminate_branches.py  |  2 +-
 .../backend/transformations/eliminate_constants.py |  8 +++++++-
 .../transformations/hoist_loop_invariant_decls.py  |  2 +-
 .../backend/transformations/loop_vectorizer.py     |  2 +-
 src/pystencils/backend/transformations/rewrite.py  |  9 ++++++++-
 8 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index 5c8fca9ad..31d8ea192 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -17,7 +17,7 @@ class PsStructuralNode(PsAstNode, ABC):
     This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from.
     """
 
-    def clone(self) -> PsStructuralNode:
+    def clone(self):
         """Clone this structure node.
 
         .. note::
diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py
index 935ac38e3..bd782422f 100644
--- a/src/pystencils/backend/transformations/add_pragmas.py
+++ b/src/pystencils/backend/transformations/add_pragmas.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 from dataclasses import dataclass
 
-from typing import Sequence, cast
+from typing import Sequence
 from collections import defaultdict
 
 from ..kernelcreation import KernelCreationContext
@@ -55,13 +55,12 @@ class InsertPragmasAtLoops:
             self._insertions[ins.loop_nesting_depth].append(ins)
 
     def __call__(self, node: PsAstNode) -> PsAstNode:
-        is_loop = isinstance(node, PsLoop)
-        if is_loop:
-            node = PsBlock([cast(PsLoop, node)])
+        if isinstance(node, PsLoop):
+            node = PsBlock([node])
 
         self.visit(node, Nesting(0))
 
-        if is_loop and len(node.children) == 1:
+        if isinstance(node, PsLoop) and len(node.children) == 1:
             node = node.children[0]
 
         return node
diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py
index 9621699d0..c793c424d 100644
--- a/src/pystencils/backend/transformations/ast_vectorizer.py
+++ b/src/pystencils/backend/transformations/ast_vectorizer.py
@@ -269,12 +269,24 @@ class AstVectorizer:
         """
         return self.visit(node, vc)
 
+    @overload
+    def visit(self, node: PsStructuralNode, vc: VectorizationContext) -> PsStructuralNode:
+        pass
+
+    @overload
+    def visit(self, node: PsExpression, vc: VectorizationContext) -> PsExpression:
+        pass
+
+    @overload
+    def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
+        pass
+
     def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
         """Vectorize a subtree."""
 
         match node:
             case PsBlock(stmts):
-                return PsBlock([cast(PsStructuralNode, self.visit(n, vc)) for n in stmts])
+                return PsBlock([self.visit(n, vc) for n in stmts])
 
             case PsExpression():
                 return self.visit_expr(node, vc)
diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py
index ca24e49b7..69dd1dd11 100644
--- a/src/pystencils/backend/transformations/eliminate_branches.py
+++ b/src/pystencils/backend/transformations/eliminate_branches.py
@@ -68,7 +68,7 @@ class EliminateBranches:
     def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode:
         match node:
             case PsLoop(_, _, _, _, body):
-                ec.enclosing_loops.append(cast(PsLoop, node))
+                ec.enclosing_loops.append(node)
                 self.visit(body, ec)
                 ec.enclosing_loops.pop()
 
diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index b66efe4f2..3a07cb56f 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -36,6 +36,7 @@ from ..ast.expressions import (
 )
 from ..ast.vector import PsVecBroadcast
 from ..ast.util import AstEqWrapper
+from ..exceptions import PsInternalCompilerError
 
 from ..constants import PsConstant
 from ..memory import PsSymbol
@@ -138,13 +139,18 @@ class EliminateConstants:
         node = self.visit(node, ecc)
 
         if ecc.extractions:
+            if not isinstance(node, PsStructuralNode):
+                raise PsInternalCompilerError(
+                    f"Cannot extract constant expressions from outermost node {node}"
+                )
+
             prepend_decls = [
                 PsDeclaration(PsExpression.make(symb), expr)
                 for symb, expr in ecc.extractions
             ]
 
             if not isinstance(node, PsBlock):
-                node = PsBlock(prepend_decls + [cast(PsStructuralNode, node)])
+                node = PsBlock(prepend_decls + [node])
             else:
                 node.children = prepend_decls + list(node.children)
 
diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
index 9637485dd..f7fe81ad7 100644
--- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
+++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
@@ -91,7 +91,7 @@ class HoistLoopInvariantDeclarations:
         """Search the outermost loop and start the hoisting cascade there."""
         match node:
             case PsLoop():
-                temp_block = PsBlock([cast(PsLoop, node)])
+                temp_block = PsBlock([node])
                 temp_block = cast(PsBlock, self.visit(temp_block))
                 if temp_block.statements == [node]:
                     return node
diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py
index 6b518a30d..e1e4fea50 100644
--- a/src/pystencils/backend/transformations/loop_vectorizer.py
+++ b/src/pystencils/backend/transformations/loop_vectorizer.py
@@ -213,7 +213,7 @@ class LoopVectorizer:
 
                 trailing_ctr = self._ctx.duplicate_symbol(scalar_ctr)
                 trailing_loop_body = substitute_symbols(
-                    loop.body._clone_structural(), {scalar_ctr: PsExpression.make(trailing_ctr)}
+                    loop.body.clone(), {scalar_ctr: PsExpression.make(trailing_ctr)}
                 )
                 trailing_loop = PsLoop(
                     PsExpression.make(trailing_ctr),
diff --git a/src/pystencils/backend/transformations/rewrite.py b/src/pystencils/backend/transformations/rewrite.py
index 59241c295..8dff9e45e 100644
--- a/src/pystencils/backend/transformations/rewrite.py
+++ b/src/pystencils/backend/transformations/rewrite.py
@@ -2,7 +2,7 @@ from typing import overload
 
 from ..memory import PsSymbol
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock
+from ..ast.structural import PsStructuralNode, PsBlock
 from ..ast.expressions import PsExpression, PsSymbolExpr
 
 
@@ -18,6 +18,13 @@ def substitute_symbols(
     pass
 
 
+@overload
+def substitute_symbols(
+    node: PsStructuralNode, subs: dict[PsSymbol, PsExpression]
+) -> PsStructuralNode:
+    pass
+
+
 @overload
 def substitute_symbols(
     node: PsAstNode, subs: dict[PsSymbol, PsExpression]
-- 
GitLab