From 7991adb679382e0a6e7a4a9ea2f8bb3b0084729f Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 5 Mar 2025 12:04:59 +0100 Subject: [PATCH] remove various casts. clean up type annotations. fix one small bug. --- src/pystencils/backend/ast/structural.py | 2 +- .../backend/transformations/add_pragmas.py | 9 ++++----- .../backend/transformations/ast_vectorizer.py | 14 +++++++++++++- .../backend/transformations/eliminate_branches.py | 2 +- .../backend/transformations/eliminate_constants.py | 8 +++++++- .../transformations/hoist_loop_invariant_decls.py | 2 +- .../backend/transformations/loop_vectorizer.py | 2 +- src/pystencils/backend/transformations/rewrite.py | 9 ++++++++- 8 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 5c8fca9ad..31d8ea192 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -17,7 +17,7 @@ class PsStructuralNode(PsAstNode, ABC): This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from. """ - def clone(self) -> PsStructuralNode: + def clone(self): """Clone this structure node. .. note:: diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 935ac38e3..bd782422f 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence, cast +from typing import Sequence from collections import defaultdict from ..kernelcreation import KernelCreationContext @@ -55,13 +55,12 @@ class InsertPragmasAtLoops: self._insertions[ins.loop_nesting_depth].append(ins) def __call__(self, node: PsAstNode) -> PsAstNode: - is_loop = isinstance(node, PsLoop) - if is_loop: - node = PsBlock([cast(PsLoop, node)]) + if isinstance(node, PsLoop): + node = PsBlock([node]) self.visit(node, Nesting(0)) - if is_loop and len(node.children) == 1: + if isinstance(node, PsLoop) and len(node.children) == 1: node = node.children[0] return node diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py index 9621699d0..c793c424d 100644 --- a/src/pystencils/backend/transformations/ast_vectorizer.py +++ b/src/pystencils/backend/transformations/ast_vectorizer.py @@ -269,12 +269,24 @@ class AstVectorizer: """ return self.visit(node, vc) + @overload + def visit(self, node: PsStructuralNode, vc: VectorizationContext) -> PsStructuralNode: + pass + + @overload + def visit(self, node: PsExpression, vc: VectorizationContext) -> PsExpression: + pass + + @overload + def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: + pass + def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: """Vectorize a subtree.""" match node: case PsBlock(stmts): - return PsBlock([cast(PsStructuralNode, self.visit(n, vc)) for n in stmts]) + return PsBlock([self.visit(n, vc) for n in stmts]) case PsExpression(): return self.visit_expr(node, vc) diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py index ca24e49b7..69dd1dd11 100644 --- a/src/pystencils/backend/transformations/eliminate_branches.py +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -68,7 +68,7 @@ class EliminateBranches: def visit(self, node: PsAstNode, ec: BranchElimContext) -> PsAstNode: match node: case PsLoop(_, _, _, _, body): - ec.enclosing_loops.append(cast(PsLoop, node)) + ec.enclosing_loops.append(node) self.visit(body, ec) ec.enclosing_loops.pop() diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index b66efe4f2..3a07cb56f 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -36,6 +36,7 @@ from ..ast.expressions import ( ) from ..ast.vector import PsVecBroadcast from ..ast.util import AstEqWrapper +from ..exceptions import PsInternalCompilerError from ..constants import PsConstant from ..memory import PsSymbol @@ -138,13 +139,18 @@ class EliminateConstants: node = self.visit(node, ecc) if ecc.extractions: + if not isinstance(node, PsStructuralNode): + raise PsInternalCompilerError( + f"Cannot extract constant expressions from outermost node {node}" + ) + prepend_decls = [ PsDeclaration(PsExpression.make(symb), expr) for symb, expr in ecc.extractions ] if not isinstance(node, PsBlock): - node = PsBlock(prepend_decls + [cast(PsStructuralNode, node)]) + node = PsBlock(prepend_decls + [node]) else: node.children = prepend_decls + list(node.children) diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index 9637485dd..f7fe81ad7 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -91,7 +91,7 @@ class HoistLoopInvariantDeclarations: """Search the outermost loop and start the hoisting cascade there.""" match node: case PsLoop(): - temp_block = PsBlock([cast(PsLoop, node)]) + temp_block = PsBlock([node]) temp_block = cast(PsBlock, self.visit(temp_block)) if temp_block.statements == [node]: return node diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py index 6b518a30d..e1e4fea50 100644 --- a/src/pystencils/backend/transformations/loop_vectorizer.py +++ b/src/pystencils/backend/transformations/loop_vectorizer.py @@ -213,7 +213,7 @@ class LoopVectorizer: trailing_ctr = self._ctx.duplicate_symbol(scalar_ctr) trailing_loop_body = substitute_symbols( - loop.body._clone_structural(), {scalar_ctr: PsExpression.make(trailing_ctr)} + loop.body.clone(), {scalar_ctr: PsExpression.make(trailing_ctr)} ) trailing_loop = PsLoop( PsExpression.make(trailing_ctr), diff --git a/src/pystencils/backend/transformations/rewrite.py b/src/pystencils/backend/transformations/rewrite.py index 59241c295..8dff9e45e 100644 --- a/src/pystencils/backend/transformations/rewrite.py +++ b/src/pystencils/backend/transformations/rewrite.py @@ -2,7 +2,7 @@ from typing import overload from ..memory import PsSymbol from ..ast import PsAstNode -from ..ast.structural import PsBlock +from ..ast.structural import PsStructuralNode, PsBlock from ..ast.expressions import PsExpression, PsSymbolExpr @@ -18,6 +18,13 @@ def substitute_symbols( pass +@overload +def substitute_symbols( + node: PsStructuralNode, subs: dict[PsSymbol, PsExpression] +) -> PsStructuralNode: + pass + + @overload def substitute_symbols( node: PsAstNode, subs: dict[PsSymbol, PsExpression] -- GitLab