diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 5c8fca9adf9546683f3f7489eb137a00f1523b9d..31d8ea192269a9a9947457814ff5e58d63f61c14 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -17,7 +17,7 @@ class PsStructuralNode(PsAstNode, ABC): This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from. """ - def clone(self) -> PsStructuralNode: + def clone(self): """Clone this structure node. .. note:: diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 935ac38e32d6cc77276e607f38e4b21d8062a70c..bd782422f1fa80b96ec7cf69473fda2b1f45c3d6 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence, cast +from typing import Sequence from collections import defaultdict from ..kernelcreation import KernelCreationContext @@ -55,13 +55,12 @@ class InsertPragmasAtLoops: self._insertions[ins.loop_nesting_depth].append(ins) def __call__(self, node: PsAstNode) -> PsAstNode: - is_loop = isinstance(node, PsLoop) - if is_loop: - node = PsBlock([cast(PsLoop, node)]) + if isinstance(node, PsLoop): + node = PsBlock([node]) self.visit(node, Nesting(0)) - if is_loop and len(node.children) == 1: + if isinstance(node, PsLoop) and len(node.children) == 1: node = node.children[0] return node diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py index 9621699d03c5ee5a06f056da961c4857f015ee04..c793c424d2417cbbdcc0cf3782e696c7c9226bb6 100644 --- a/src/pystencils/backend/transformations/ast_vectorizer.py +++ b/src/pystencils/backend/transformations/ast_vectorizer.py @@ -269,12 +269,24 @@ class AstVectorizer: """ return self.visit(node, vc) + @overload + def visit(self, node: PsStructuralNode, vc: VectorizationContext) -> PsStructuralNode: + pass + + @overload + def visit(self, node: PsExpression, vc: VectorizationContext) -> PsExpression: + pass + + @overload + def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: + pass + def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: """Vectorize a subtree.""" match node: case PsBlock(stmts): - return PsBlock([cast(PsStructuralNode, self.visit(n, vc)) for n in stmts]) + return PsBlock([self.visit(n, vc) for n in stmts]) case PsExpression(): return self.visit_expr(node, vc) diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py index ca24e49b774d58007c5652a72aaa0ec4d8f4c9f6..69dd1dd11d726e597c15ece772846ba8cd84acba 100644 --- a/src/pystencils/backend/transformations/eliminate_branches.py +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -68,7 +68,7 @@ class EliminateBranches: def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode: match node: case PsLoop(_, _, _, _, body): - ec.enclosing_loops.append(cast(PsLoop, node)) + ec.enclosing_loops.append(node) self.visit(body, ec) ec.enclosing_loops.pop() diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index b66efe4f25b3ef37019595abf296e42c3343f272..3a07cb56fcb8f1c60107b5b1883c679191429e7e 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -36,6 +36,7 @@ from ..ast.expressions import ( ) from ..ast.vector import PsVecBroadcast from ..ast.util import AstEqWrapper +from ..exceptions import PsInternalCompilerError from ..constants import PsConstant from ..memory import PsSymbol @@ -138,13 +139,18 @@ class EliminateConstants: node = self.visit(node, ecc) if ecc.extractions: + if not isinstance(node, PsStructuralNode): + raise PsInternalCompilerError( + f"Cannot extract constant expressions from outermost node {node}" + ) + prepend_decls = [ PsDeclaration(PsExpression.make(symb), expr) for symb, expr in ecc.extractions ] if not isinstance(node, PsBlock): - node = PsBlock(prepend_decls + [cast(PsStructuralNode, node)]) + node = PsBlock(prepend_decls + [node]) else: node.children = prepend_decls + list(node.children) diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index 9637485dde4cc296cedbc0e24329949fda027292..f7fe81ad736981bee6f38427fbd4face73f0c455 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -91,7 +91,7 @@ class HoistLoopInvariantDeclarations: """Search the outermost loop and start the hoisting cascade there.""" match node: case PsLoop(): - temp_block = PsBlock([cast(PsLoop, node)]) + temp_block = PsBlock([node]) temp_block = cast(PsBlock, self.visit(temp_block)) if temp_block.statements == [node]: return node diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py index 6b518a30de31acac67c061ead1683c1d6ab06816..e1e4fea502c08de86e13de5e3c251f1b7a7d0ee6 100644 --- a/src/pystencils/backend/transformations/loop_vectorizer.py +++ b/src/pystencils/backend/transformations/loop_vectorizer.py @@ -213,7 +213,7 @@ class LoopVectorizer: trailing_ctr = self._ctx.duplicate_symbol(scalar_ctr) trailing_loop_body = substitute_symbols( - loop.body._clone_structural(), {scalar_ctr: PsExpression.make(trailing_ctr)} + loop.body.clone(), {scalar_ctr: PsExpression.make(trailing_ctr)} ) trailing_loop = PsLoop( PsExpression.make(trailing_ctr), diff --git a/src/pystencils/backend/transformations/rewrite.py b/src/pystencils/backend/transformations/rewrite.py index 59241c295f42eeaf60f4cd03a5138214fdbd6c50..8dff9e45ec283fc6c3712c2e77ff56a9b2aaeae5 100644 --- a/src/pystencils/backend/transformations/rewrite.py +++ b/src/pystencils/backend/transformations/rewrite.py @@ -2,7 +2,7 @@ from typing import overload from ..memory import PsSymbol from ..ast import PsAstNode -from ..ast.structural import PsBlock +from ..ast.structural import PsStructuralNode, PsBlock from ..ast.expressions import PsExpression, PsSymbolExpr @@ -18,6 +18,13 @@ def substitute_symbols( pass +@overload +def substitute_symbols( + node: PsStructuralNode, subs: dict[PsSymbol, PsExpression] +) -> PsStructuralNode: + pass + + @overload def substitute_symbols( node: PsAstNode, subs: dict[PsSymbol, PsExpression]