From 65f310985cdba73caee4f8b0d4c0cc50850174c0 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Fri, 21 Feb 2025 19:57:24 +0100 Subject: [PATCH 1/7] Introduce structural ast node trait and employ in structural.py --- src/pystencils/backend/ast/structural.py | 31 +++++++++++++++--------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 2c79f4f46..98ec72039 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -1,4 +1,6 @@ from __future__ import annotations + +from abc import ABC from typing import Iterable, Sequence, cast from types import NoneType @@ -9,10 +11,17 @@ from ..memory import PsSymbol from .util import failing_cast -class PsBlock(PsAstNode): +class PsStructuralAstNode(PsAstNode, ABC): + """Base class for structural nodes in the pystencils AST. + + This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from. + """ + + +class PsBlock(PsStructuralAstNode): __match_args__ = ("statements",) - def __init__(self, cs: Iterable[PsAstNode]): + def __init__(self, cs: Iterable[PsStructuralAstNode]): self._statements = list(cs) @property @@ -27,17 +36,17 @@ class PsBlock(PsAstNode): return tuple(self._statements) def set_child(self, idx: int, c: PsAstNode): - self._statements[idx] = c + self._statements[idx] = failing_cast(PsStructuralAstNode, c) def clone(self) -> PsBlock: return PsBlock([stmt.clone() for stmt in self._statements]) @property - def statements(self) -> list[PsAstNode]: + def statements(self) -> list[PsStructuralAstNode]: return self._statements @statements.setter - def statements(self, stm: Sequence[PsAstNode]): + def statements(self, stm: Sequence[PsStructuralAstNode]): self._statements = list(stm) def __repr__(self) -> str: @@ -45,7 +54,7 @@ class PsBlock(PsAstNode): return f"PsBlock( {contents} )" -class PsStatement(PsAstNode): +class PsStatement(PsStructuralAstNode): __match_args__ = ("expression",) def __init__(self, expr: PsExpression): @@ -71,7 +80,7 @@ class PsStatement(PsAstNode): self._expression = failing_cast(PsExpression, c) -class PsAssignment(PsAstNode): +class PsAssignment(PsStructuralAstNode): __match_args__ = ( "lhs", "rhs", @@ -157,7 +166,7 @@ class PsDeclaration(PsAssignment): return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})" -class PsLoop(PsAstNode): +class PsLoop(PsStructuralAstNode): __match_args__ = ("counter", "start", "stop", "step", "body") def __init__( @@ -243,7 +252,7 @@ class PsLoop(PsAstNode): assert False, "unreachable code" -class PsConditional(PsAstNode): +class PsConditional(PsStructuralAstNode): """Conditional branch""" __match_args__ = ("condition", "branch_true", "branch_false") @@ -317,7 +326,7 @@ class PsEmptyLeafMixIn: pass -class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): +class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralAstNode): """A C/C++ preprocessor pragma. Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``. @@ -345,7 +354,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): return self._text == other._text -class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): +class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralAstNode): __match_args__ = ("lines",) def __init__(self, text: str) -> None: -- GitLab From a4b86fdd0611be29ed7c907ce41f46990bbef43e Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Thu, 27 Feb 2025 17:32:24 +0100 Subject: [PATCH 2/7] Try fixing typecheck for newly introduced structural ast nodes --- src/pystencils/backend/ast/structural.py | 4 ++-- src/pystencils/backend/kernelcreation/freeze.py | 3 ++- src/pystencils/backend/transformations/add_pragmas.py | 8 ++++---- .../backend/transformations/ast_vectorizer.py | 3 ++- .../backend/transformations/eliminate_branches.py | 10 ++++++---- .../backend/transformations/eliminate_constants.py | 4 ++-- .../transformations/hoist_loop_invariant_decls.py | 10 +++++----- 7 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 98ec72039..5966304e2 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -30,7 +30,7 @@ class PsBlock(PsStructuralAstNode): @children.setter def children(self, cs: Sequence[PsAstNode]): - self._statements = list(cs) + self._statements = list([failing_cast(PsStructuralAstNode, c) for c in cs]) def get_children(self) -> tuple[PsAstNode, ...]: return tuple(self._statements) @@ -39,7 +39,7 @@ class PsBlock(PsStructuralAstNode): self._statements[idx] = failing_cast(PsStructuralAstNode, c) def clone(self) -> PsBlock: - return PsBlock([stmt.clone() for stmt in self._statements]) + return PsBlock([failing_cast(PsStructuralAstNode, stmt.clone()) for stmt in self._statements]) @property def statements(self) -> list[PsStructuralAstNode]: diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 4fd09f879..2213320c8 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -26,6 +26,7 @@ from ..ast.structural import ( PsDeclaration, PsExpression, PsSymbolExpr, + PsStructuralAstNode, ) from ..ast.expressions import ( PsBufferAcc, @@ -107,7 +108,7 @@ class FreezeExpressions: def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode: if isinstance(obj, AssignmentCollection): - return PsBlock([self.visit(asm) for asm in obj.all_assignments]) + return PsBlock([cast(PsStructuralAstNode, self.visit(asm)) for asm in obj.all_assignments]) elif isinstance(obj, AssignmentBase): return cast(PsAssignment, self.visit(obj)) elif isinstance(obj, _ExprLike): diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 0e6d314ac..3b6a2c18d 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -1,12 +1,12 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence +from typing import Sequence, cast from collections import defaultdict from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsLoop, PsPragma +from ..ast.structural import PsBlock, PsLoop, PsPragma, PsStructuralAstNode from ..ast.expressions import PsExpression @@ -57,7 +57,7 @@ class InsertPragmasAtLoops: def __call__(self, node: PsAstNode) -> PsAstNode: is_loop = isinstance(node, PsLoop) if is_loop: - node = PsBlock([node]) + node = PsBlock([cast(PsLoop, node)]) self.visit(node, Nesting(0)) @@ -72,7 +72,7 @@ class InsertPragmasAtLoops: return case PsBlock(children): - new_children: list[PsAstNode] = [] + new_children: list[PsStructuralAstNode] = [] for c in children: if isinstance(c, PsLoop): nest.has_inner_loops = True diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py index ab4401f9c..93484932d 100644 --- a/src/pystencils/backend/transformations/ast_vectorizer.py +++ b/src/pystencils/backend/transformations/ast_vectorizer.py @@ -18,6 +18,7 @@ from ..ast.structural import ( PsAssignment, PsLoop, PsEmptyLeafMixIn, + PsStructuralAstNode, ) from ..ast.expressions import ( PsExpression, @@ -273,7 +274,7 @@ class AstVectorizer: match node: case PsBlock(stmts): - return PsBlock([self.visit(n, vc) for n in stmts]) + return PsBlock([cast(PsStructuralAstNode, self.visit(n, vc)) for n in stmts]) case PsExpression(): return self.visit_expr(node, vc) diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py index f098d82df..02e406bc8 100644 --- a/src/pystencils/backend/transformations/eliminate_branches.py +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -1,7 +1,9 @@ +from typing import cast + from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode from ..ast.analysis import collect_undefined_symbols -from ..ast.structural import PsLoop, PsBlock, PsConditional +from ..ast.structural import PsLoop, PsBlock, PsConditional, PsStructuralAstNode from ..ast.expressions import ( PsAnd, PsCast, @@ -66,14 +68,14 @@ class EliminateBranches: def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode: match node: case PsLoop(_, _, _, _, body): - ec.enclosing_loops.append(node) + ec.enclosing_loops.append(cast(PsLoop, node)) self.visit(body, ec) ec.enclosing_loops.pop() case PsBlock(statements): - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralAstNode] = [] for stmt in statements: - statements_new.append(self.visit(stmt, ec)) + statements_new.append(cast(PsStructuralAstNode, self.visit(stmt, ec))) node.statements = statements_new case PsConditional(): diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index ab1cabc55..ea59e4f23 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -6,7 +6,7 @@ import numpy as np from ..kernelcreation import KernelCreationContext, Typifier from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsDeclaration +from ..ast.structural import PsBlock, PsDeclaration, PsStructuralAstNode from ..ast.expressions import ( PsExpression, PsConstantExpr, @@ -144,7 +144,7 @@ class EliminateConstants: ] if not isinstance(node, PsBlock): - node = PsBlock(prepend_decls + [node]) + node = PsBlock(prepend_decls + [cast(PsStructuralAstNode, node)]) else: node.children = prepend_decls + list(node.children) diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index f0e4cc9f1..7369b3ef0 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -2,7 +2,7 @@ from typing import cast from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment +from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment, PsStructuralAstNode from ..ast.expressions import ( PsExpression, PsSymbolExpr, @@ -91,7 +91,7 @@ class HoistLoopInvariantDeclarations: """Search the outermost loop and start the hoisting cascade there.""" match node: case PsLoop(): - temp_block = PsBlock([node]) + temp_block = PsBlock([cast(PsLoop, node)]) temp_block = cast(PsBlock, self.visit(temp_block)) if temp_block.statements == [node]: return node @@ -99,7 +99,7 @@ class HoistLoopInvariantDeclarations: return temp_block case PsBlock(statements): - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralAstNode] = [] for stmt in statements: if isinstance(stmt, PsLoop): loop = stmt @@ -153,7 +153,7 @@ class HoistLoopInvariantDeclarations: return case PsBlock(statements): - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralAstNode] = [] for stmt in statements: if isinstance(stmt, PsLoop): loop = stmt @@ -178,7 +178,7 @@ class HoistLoopInvariantDeclarations: This method processes only statements of the given block, and any blocks directly nested inside it. It does not descend into control structures like conditionals and nested loops. """ - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralAstNode] = [] for node in block.statements: if isinstance(node, PsDeclaration): -- GitLab From 7ea2382e51410078a7813450dbe2292895bc1121 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Mon, 3 Mar 2025 16:15:56 +0100 Subject: [PATCH 3/7] Fix typecheck --- src/pystencils/backend/transformations/add_pragmas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 3b6a2c18d..b7d66fbbd 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -91,8 +91,8 @@ class InsertPragmasAtLoops: node.children = new_children case other: - for c in other.children: - self.visit(c, nest) + for child in other.children: + self.visit(child, nest) class AddOpenMP: -- GitLab From 216ef8bcc504647428b8f6c128ccf1aee6694ce8 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Mon, 3 Mar 2025 16:18:02 +0100 Subject: [PATCH 4/7] Rename newly introduced node to PsStructuralNode --- src/pystencils/backend/ast/structural.py | 28 +++++++++---------- .../backend/kernelcreation/freeze.py | 4 +-- .../backend/transformations/add_pragmas.py | 4 +-- .../backend/transformations/ast_vectorizer.py | 4 +-- .../transformations/eliminate_branches.py | 6 ++-- .../transformations/eliminate_constants.py | 4 +-- .../hoist_loop_invariant_decls.py | 8 +++--- 7 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 5966304e2..c25579029 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -11,17 +11,17 @@ from ..memory import PsSymbol from .util import failing_cast -class PsStructuralAstNode(PsAstNode, ABC): +class PsStructuralNode(PsAstNode, ABC): """Base class for structural nodes in the pystencils AST. This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from. """ -class PsBlock(PsStructuralAstNode): +class PsBlock(PsStructuralNode): __match_args__ = ("statements",) - def __init__(self, cs: Iterable[PsStructuralAstNode]): + def __init__(self, cs: Iterable[PsStructuralNode]): self._statements = list(cs) @property @@ -30,23 +30,23 @@ class PsBlock(PsStructuralAstNode): @children.setter def children(self, cs: Sequence[PsAstNode]): - self._statements = list([failing_cast(PsStructuralAstNode, c) for c in cs]) + self._statements = list([failing_cast(PsStructuralNode, c) for c in cs]) def get_children(self) -> tuple[PsAstNode, ...]: return tuple(self._statements) def set_child(self, idx: int, c: PsAstNode): - self._statements[idx] = failing_cast(PsStructuralAstNode, c) + self._statements[idx] = failing_cast(PsStructuralNode, c) def clone(self) -> PsBlock: - return PsBlock([failing_cast(PsStructuralAstNode, stmt.clone()) for stmt in self._statements]) + return PsBlock([failing_cast(PsStructuralNode, stmt.clone()) for stmt in self._statements]) @property - def statements(self) -> list[PsStructuralAstNode]: + def statements(self) -> list[PsStructuralNode]: return self._statements @statements.setter - def statements(self, stm: Sequence[PsStructuralAstNode]): + def statements(self, stm: Sequence[PsStructuralNode]): self._statements = list(stm) def __repr__(self) -> str: @@ -54,7 +54,7 @@ class PsBlock(PsStructuralAstNode): return f"PsBlock( {contents} )" -class PsStatement(PsStructuralAstNode): +class PsStatement(PsStructuralNode): __match_args__ = ("expression",) def __init__(self, expr: PsExpression): @@ -80,7 +80,7 @@ class PsStatement(PsStructuralAstNode): self._expression = failing_cast(PsExpression, c) -class PsAssignment(PsStructuralAstNode): +class PsAssignment(PsStructuralNode): __match_args__ = ( "lhs", "rhs", @@ -166,7 +166,7 @@ class PsDeclaration(PsAssignment): return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})" -class PsLoop(PsStructuralAstNode): +class PsLoop(PsStructuralNode): __match_args__ = ("counter", "start", "stop", "step", "body") def __init__( @@ -252,7 +252,7 @@ class PsLoop(PsStructuralAstNode): assert False, "unreachable code" -class PsConditional(PsStructuralAstNode): +class PsConditional(PsStructuralNode): """Conditional branch""" __match_args__ = ("condition", "branch_true", "branch_false") @@ -326,7 +326,7 @@ class PsEmptyLeafMixIn: pass -class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralAstNode): +class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode): """A C/C++ preprocessor pragma. Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``. @@ -354,7 +354,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralAstNode): return self._text == other._text -class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralAstNode): +class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode): __match_args__ = ("lines",) def __init__(self, text: str) -> None: diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 2213320c8..b3ff5aefb 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -26,7 +26,7 @@ from ..ast.structural import ( PsDeclaration, PsExpression, PsSymbolExpr, - PsStructuralAstNode, + PsStructuralNode, ) from ..ast.expressions import ( PsBufferAcc, @@ -108,7 +108,7 @@ class FreezeExpressions: def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode: if isinstance(obj, AssignmentCollection): - return PsBlock([cast(PsStructuralAstNode, self.visit(asm)) for asm in obj.all_assignments]) + return PsBlock([cast(PsStructuralNode, self.visit(asm)) for asm in obj.all_assignments]) elif isinstance(obj, AssignmentBase): return cast(PsAssignment, self.visit(obj)) elif isinstance(obj, _ExprLike): diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index b7d66fbbd..935ac38e3 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -6,7 +6,7 @@ from collections import defaultdict from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsLoop, PsPragma, PsStructuralAstNode +from ..ast.structural import PsBlock, PsLoop, PsPragma, PsStructuralNode from ..ast.expressions import PsExpression @@ -72,7 +72,7 @@ class InsertPragmasAtLoops: return case PsBlock(children): - new_children: list[PsStructuralAstNode] = [] + new_children: list[PsStructuralNode] = [] for c in children: if isinstance(c, PsLoop): nest.has_inner_loops = True diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py index 93484932d..9621699d0 100644 --- a/src/pystencils/backend/transformations/ast_vectorizer.py +++ b/src/pystencils/backend/transformations/ast_vectorizer.py @@ -18,7 +18,7 @@ from ..ast.structural import ( PsAssignment, PsLoop, PsEmptyLeafMixIn, - PsStructuralAstNode, + PsStructuralNode, ) from ..ast.expressions import ( PsExpression, @@ -274,7 +274,7 @@ class AstVectorizer: match node: case PsBlock(stmts): - return PsBlock([cast(PsStructuralAstNode, self.visit(n, vc)) for n in stmts]) + return PsBlock([cast(PsStructuralNode, self.visit(n, vc)) for n in stmts]) case PsExpression(): return self.visit_expr(node, vc) diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py index 02e406bc8..ca24e49b7 100644 --- a/src/pystencils/backend/transformations/eliminate_branches.py +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -3,7 +3,7 @@ from typing import cast from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode from ..ast.analysis import collect_undefined_symbols -from ..ast.structural import PsLoop, PsBlock, PsConditional, PsStructuralAstNode +from ..ast.structural import PsLoop, PsBlock, PsConditional, PsStructuralNode from ..ast.expressions import ( PsAnd, PsCast, @@ -73,9 +73,9 @@ class EliminateBranches: ec.enclosing_loops.pop() case PsBlock(statements): - statements_new: list[PsStructuralAstNode] = [] + statements_new: list[PsStructuralNode] = [] for stmt in statements: - statements_new.append(cast(PsStructuralAstNode, self.visit(stmt, ec))) + statements_new.append(cast(PsStructuralNode, self.visit(stmt, ec))) node.statements = statements_new case PsConditional(): diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index ea59e4f23..b66efe4f2 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -6,7 +6,7 @@ import numpy as np from ..kernelcreation import KernelCreationContext, Typifier from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsDeclaration, PsStructuralAstNode +from ..ast.structural import PsBlock, PsDeclaration, PsStructuralNode from ..ast.expressions import ( PsExpression, PsConstantExpr, @@ -144,7 +144,7 @@ class EliminateConstants: ] if not isinstance(node, PsBlock): - node = PsBlock(prepend_decls + [cast(PsStructuralAstNode, node)]) + node = PsBlock(prepend_decls + [cast(PsStructuralNode, node)]) else: node.children = prepend_decls + list(node.children) diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index 7369b3ef0..9637485dd 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -2,7 +2,7 @@ from typing import cast from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment, PsStructuralAstNode +from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment, PsStructuralNode from ..ast.expressions import ( PsExpression, PsSymbolExpr, @@ -99,7 +99,7 @@ class HoistLoopInvariantDeclarations: return temp_block case PsBlock(statements): - statements_new: list[PsStructuralAstNode] = [] + statements_new: list[PsStructuralNode] = [] for stmt in statements: if isinstance(stmt, PsLoop): loop = stmt @@ -153,7 +153,7 @@ class HoistLoopInvariantDeclarations: return case PsBlock(statements): - statements_new: list[PsStructuralAstNode] = [] + statements_new: list[PsStructuralNode] = [] for stmt in statements: if isinstance(stmt, PsLoop): loop = stmt @@ -178,7 +178,7 @@ class HoistLoopInvariantDeclarations: This method processes only statements of the given block, and any blocks directly nested inside it. It does not descend into control structures like conditionals and nested loops. """ - statements_new: list[PsStructuralAstNode] = [] + statements_new: list[PsStructuralNode] = [] for node in block.statements: if isinstance(node, PsDeclaration): -- GitLab From 8ceb678194386094f9ef8ddc8b8af74941135295 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Mon, 3 Mar 2025 16:34:40 +0100 Subject: [PATCH 5/7] Introduce _clone_structural function for PsStructural node (similar to PsExpression) --- src/pystencils/backend/ast/structural.py | 38 +++++++++++++++++------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index c25579029..b0dace4e0 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from typing import Iterable, Sequence, cast from types import NoneType @@ -17,6 +17,24 @@ class PsStructuralNode(PsAstNode, ABC): This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from. """ + def clone(self) -> PsStructuralNode: + """Clone this structure node. + + .. note:: + Subclasses of `PsStructuralNode` should not override this method, + but implement `_clone_structural` instead. + That implementation shall call `clone` on any of its children. + """ + return self._clone_structural() + + @abstractmethod + def _clone_structural(self) -> PsStructuralNode: + """Implementation of structural node cloning. + + :meta public: + """ + pass + class PsBlock(PsStructuralNode): __match_args__ = ("statements",) @@ -38,8 +56,8 @@ class PsBlock(PsStructuralNode): def set_child(self, idx: int, c: PsAstNode): self._statements[idx] = failing_cast(PsStructuralNode, c) - def clone(self) -> PsBlock: - return PsBlock([failing_cast(PsStructuralNode, stmt.clone()) for stmt in self._statements]) + def _clone_structural(self) -> PsBlock: + return PsBlock([stmt._clone_structural() for stmt in self._statements]) @property def statements(self) -> list[PsStructuralNode]: @@ -68,7 +86,7 @@ class PsStatement(PsStructuralNode): def expression(self, expr: PsExpression): self._expression = expr - def clone(self) -> PsStatement: + def _clone_structural(self) -> PsStatement: return PsStatement(self._expression.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -110,7 +128,7 @@ class PsAssignment(PsStructuralNode): def rhs(self, expr: PsExpression): self._rhs = expr - def clone(self) -> PsAssignment: + def _clone_structural(self) -> PsAssignment: return PsAssignment(self._lhs.clone(), self._rhs.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -150,7 +168,7 @@ class PsDeclaration(PsAssignment): def declared_symbol(self) -> PsSymbol: return cast(PsSymbolExpr, self._lhs).symbol - def clone(self) -> PsDeclaration: + def _clone_structural(self) -> PsDeclaration: return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone()) def set_child(self, idx: int, c: PsAstNode): @@ -223,7 +241,7 @@ class PsLoop(PsStructuralNode): def body(self, block: PsBlock): self._body = block - def clone(self) -> PsLoop: + def _clone_structural(self) -> PsLoop: return PsLoop( self._ctr.clone(), self._start.clone(), @@ -291,7 +309,7 @@ class PsConditional(PsStructuralNode): def branch_false(self, block: PsBlock | None): self._branch_false = block - def clone(self) -> PsConditional: + def _clone_structural(self) -> PsConditional: return PsConditional( self._condition.clone(), self._branch_true.clone(), @@ -344,7 +362,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode): def text(self) -> str: return self._text - def clone(self) -> PsPragma: + def _clone_structural(self) -> PsPragma: return PsPragma(self.text) def structurally_equal(self, other: PsAstNode) -> bool: @@ -369,7 +387,7 @@ class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode): def lines(self) -> tuple[str, ...]: return self._lines - def clone(self) -> PsComment: + def _clone_structural(self) -> PsComment: return PsComment(self._text) def structurally_equal(self, other: PsAstNode) -> bool: -- GitLab From 84b4c13bf6758f1cc664eebb647f57740d83e831 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Mon, 3 Mar 2025 16:39:59 +0100 Subject: [PATCH 6/7] Fix typecheck --- src/pystencils/backend/ast/structural.py | 6 +++--- src/pystencils/backend/transformations/loop_vectorizer.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index b0dace4e0..5c8fca9ad 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -247,7 +247,7 @@ class PsLoop(PsStructuralNode): self._start.clone(), self._stop.clone(), self._step.clone(), - self._body.clone(), + self._body._clone_structural(), ) def get_children(self) -> tuple[PsAstNode, ...]: @@ -312,8 +312,8 @@ class PsConditional(PsStructuralNode): def _clone_structural(self) -> PsConditional: return PsConditional( self._condition.clone(), - self._branch_true.clone(), - self._branch_false.clone() if self._branch_false is not None else None, + self._branch_true._clone_structural(), + self._branch_false._clone_structural() if self._branch_false is not None else None, ) def get_children(self) -> tuple[PsAstNode, ...]: diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py index e1e4fea50..6b518a30d 100644 --- a/src/pystencils/backend/transformations/loop_vectorizer.py +++ b/src/pystencils/backend/transformations/loop_vectorizer.py @@ -213,7 +213,7 @@ class LoopVectorizer: trailing_ctr = self._ctx.duplicate_symbol(scalar_ctr) trailing_loop_body = substitute_symbols( - loop.body.clone(), {scalar_ctr: PsExpression.make(trailing_ctr)} + loop.body._clone_structural(), {scalar_ctr: PsExpression.make(trailing_ctr)} ) trailing_loop = PsLoop( PsExpression.make(trailing_ctr), -- GitLab From 7991adb679382e0a6e7a4a9ea2f8bb3b0084729f Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 5 Mar 2025 12:04:59 +0100 Subject: [PATCH 7/7] remove various casts. clean up type annotations. fix one small bug. --- src/pystencils/backend/ast/structural.py | 2 +- .../backend/transformations/add_pragmas.py | 9 ++++----- .../backend/transformations/ast_vectorizer.py | 14 +++++++++++++- .../backend/transformations/eliminate_branches.py | 2 +- .../backend/transformations/eliminate_constants.py | 8 +++++++- .../transformations/hoist_loop_invariant_decls.py | 2 +- .../backend/transformations/loop_vectorizer.py | 2 +- src/pystencils/backend/transformations/rewrite.py | 9 ++++++++- 8 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 5c8fca9ad..31d8ea192 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -17,7 +17,7 @@ class PsStructuralNode(PsAstNode, ABC): This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from. """ - def clone(self) -> PsStructuralNode: + def clone(self): """Clone this structure node. .. note:: diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 935ac38e3..bd782422f 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence, cast +from typing import Sequence from collections import defaultdict from ..kernelcreation import KernelCreationContext @@ -55,13 +55,12 @@ class InsertPragmasAtLoops: self._insertions[ins.loop_nesting_depth].append(ins) def __call__(self, node: PsAstNode) -> PsAstNode: - is_loop = isinstance(node, PsLoop) - if is_loop: - node = PsBlock([cast(PsLoop, node)]) + if isinstance(node, PsLoop): + node = PsBlock([node]) self.visit(node, Nesting(0)) - if is_loop and len(node.children) == 1: + if isinstance(node, PsLoop) and len(node.children) == 1: node = node.children[0] return node diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py index 9621699d0..c793c424d 100644 --- a/src/pystencils/backend/transformations/ast_vectorizer.py +++ b/src/pystencils/backend/transformations/ast_vectorizer.py @@ -269,12 +269,24 @@ class AstVectorizer: """ return self.visit(node, vc) + @overload + def visit(self, node: PsStructuralNode, vc: VectorizationContext) -> PsStructuralNode: + pass + + @overload + def visit(self, node: PsExpression, vc: VectorizationContext) -> PsExpression: + pass + + @overload + def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: + pass + def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: """Vectorize a subtree.""" match node: case PsBlock(stmts): - return PsBlock([cast(PsStructuralNode, self.visit(n, vc)) for n in stmts]) + return PsBlock([self.visit(n, vc) for n in stmts]) case PsExpression(): return self.visit_expr(node, vc) diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py index ca24e49b7..69dd1dd11 100644 --- a/src/pystencils/backend/transformations/eliminate_branches.py +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -68,7 +68,7 @@ class EliminateBranches: def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode: match node: case PsLoop(_, _, _, _, body): - ec.enclosing_loops.append(cast(PsLoop, node)) + ec.enclosing_loops.append(node) self.visit(body, ec) ec.enclosing_loops.pop() diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index b66efe4f2..3a07cb56f 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -36,6 +36,7 @@ from ..ast.expressions import ( ) from ..ast.vector import PsVecBroadcast from ..ast.util import AstEqWrapper +from ..exceptions import PsInternalCompilerError from ..constants import PsConstant from ..memory import PsSymbol @@ -138,13 +139,18 @@ class EliminateConstants: node = self.visit(node, ecc) if ecc.extractions: + if not isinstance(node, PsStructuralNode): + raise PsInternalCompilerError( + f"Cannot extract constant expressions from outermost node {node}" + ) + prepend_decls = [ PsDeclaration(PsExpression.make(symb), expr) for symb, expr in ecc.extractions ] if not isinstance(node, PsBlock): - node = PsBlock(prepend_decls + [cast(PsStructuralNode, node)]) + node = PsBlock(prepend_decls + [node]) else: node.children = prepend_decls + list(node.children) diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index 9637485dd..f7fe81ad7 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -91,7 +91,7 @@ class HoistLoopInvariantDeclarations: """Search the outermost loop and start the hoisting cascade there.""" match node: case PsLoop(): - temp_block = PsBlock([cast(PsLoop, node)]) + temp_block = PsBlock([node]) temp_block = cast(PsBlock, self.visit(temp_block)) if temp_block.statements == [node]: return node diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py index 6b518a30d..e1e4fea50 100644 --- a/src/pystencils/backend/transformations/loop_vectorizer.py +++ b/src/pystencils/backend/transformations/loop_vectorizer.py @@ -213,7 +213,7 @@ class LoopVectorizer: trailing_ctr = self._ctx.duplicate_symbol(scalar_ctr) trailing_loop_body = substitute_symbols( - loop.body._clone_structural(), {scalar_ctr: PsExpression.make(trailing_ctr)} + loop.body.clone(), {scalar_ctr: PsExpression.make(trailing_ctr)} ) trailing_loop = PsLoop( PsExpression.make(trailing_ctr), diff --git a/src/pystencils/backend/transformations/rewrite.py b/src/pystencils/backend/transformations/rewrite.py index 59241c295..8dff9e45e 100644 --- a/src/pystencils/backend/transformations/rewrite.py +++ b/src/pystencils/backend/transformations/rewrite.py @@ -2,7 +2,7 @@ from typing import overload from ..memory import PsSymbol from ..ast import PsAstNode -from ..ast.structural import PsBlock +from ..ast.structural import PsStructuralNode, PsBlock from ..ast.expressions import PsExpression, PsSymbolExpr @@ -18,6 +18,13 @@ def substitute_symbols( pass +@overload +def substitute_symbols( + node: PsStructuralNode, subs: dict[PsSymbol, PsExpression] +) -> PsStructuralNode: + pass + + @overload def substitute_symbols( node: PsAstNode, subs: dict[PsSymbol, PsExpression] -- GitLab