diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 2f873ff29052adcf27eeed4a08741097a3b52d3b..43b048184144c4c5403bbb8ca7aa1daab089da83 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -22,7 +22,7 @@ from ..kernelcreation.iteration_space import ( ) from ..constants import PsConstant -from ..ast.structural import PsDeclaration, PsLoop, PsBlock +from ..ast.structural import PsDeclaration, PsLoop, PsBlock, PsStructuralNode from ..ast.expressions import ( PsSymbolExpr, PsExpression, @@ -60,7 +60,7 @@ class GenericCpu(Platform): else: raise MaterializationError(f"Unknown type of iteration space: {ispace}") - def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]: + def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]: call_func = call.function assert isinstance(call_func, PsReductionFunction | PsMathFunction) diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index d3e8de42db572c7cc2cdfa807f777ccd4bd76a89..9b21457be54c8871d95c466592efe68565634a36 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -24,7 +24,7 @@ from ..kernelcreation import ( ) from ..kernelcreation.context import KernelCreationContext -from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment +from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment, PsStructuralNode from ..ast.expressions import ( PsExpression, PsLiteralExpr, @@ -238,7 +238,7 @@ class GenericGpu(Platform): else: raise MaterializationError(f"Unknown type of iteration space: {ispace}") - def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]: + def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]: call_func = call.function assert isinstance(call_func, PsReductionFunction | PsMathFunction) diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py index 437962172bd35c6e131d659c34b83293c64c0f5e..4f738dd5ddd67c831d6fd80f0c59be3915553058 100644 --- a/src/pystencils/backend/platforms/platform.py +++ b/src/pystencils/backend/platforms/platform.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from ..ast import PsAstNode -from ..ast.structural import PsBlock +from ..ast.structural import PsBlock, PsStructuralNode from ..ast.expressions import PsCall, PsExpression from ..kernelcreation.context import KernelCreationContext @@ -38,7 +38,7 @@ class Platform(ABC): @abstractmethod def select_function( self, call: PsCall - ) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]: + ) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]: """Select an implementation for the given function on the given data type. If no viable implementation exists, raise a `MaterializationError`. diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index 7d7b8d1a754ffc4bd2235c399f644dd85d6e5664..78af01b2f3c63b38e6c53c77686c8bff76f7d40c 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -7,7 +7,7 @@ from ..kernelcreation.iteration_space import ( FullIterationSpace, SparseIterationSpace, ) -from ..ast.structural import PsDeclaration, PsBlock, PsConditional +from ..ast.structural import PsDeclaration, PsBlock, PsConditional, PsStructuralNode from ..ast.expressions import ( PsExpression, PsSymbolExpr, @@ -56,7 +56,7 @@ class SyclPlatform(Platform): else: raise MaterializationError(f"Unknown type of iteration space: {ispace}") - def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]: + def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]: assert isinstance(call.function, PsMathFunction) func = call.function.func diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py index b78114553464f6d04c59f3d2f4a3e65ec950b75f..a96c6af4b14e00f19098fbb4113b42c800ec71c2 100644 --- a/src/pystencils/backend/transformations/loop_vectorizer.py +++ b/src/pystencils/backend/transformations/loop_vectorizer.py @@ -7,7 +7,7 @@ from ...types import PsVectorType, PsScalarType from ..kernelcreation import KernelCreationContext from ..constants import PsConstant from ..ast import PsAstNode -from ..ast.structural import PsLoop, PsBlock, PsDeclaration, PsAssignment +from ..ast.structural import PsLoop, PsBlock, PsDeclaration, PsAssignment, PsStructuralNode from ..ast.expressions import PsExpression, PsTernary, PsGt, PsSymbolExpr from ..ast.vector import PsVecBroadcast, PsVecHorizontal from ..ast.analysis import collect_undefined_symbols @@ -135,20 +135,20 @@ class LoopVectorizer: vc = VectorizationContext(self._ctx, self._lanes, axis) # Prepare reductions - simd_init_local_reduction_vars = [] - simd_writeback_local_reduction_vars = [] + simd_init_local_reduction_vars: list[PsStructuralNode] = [] + simd_writeback_local_reduction_vars: list[PsStructuralNode] = [] for symb, reduction_info in self._ctx.symbols_reduction_info.items(): # Vectorize symbol for local copy vector_symb = vc.vectorize_symbol(symb) # Declare and init vector - simd_init_local_reduction_vars += [self._type_fold(PsDeclaration( - PsSymbolExpr(vector_symb), PsVecBroadcast(self._lanes, PsSymbolExpr(symb))))] + simd_init_local_reduction_vars += [PsDeclaration( + PsSymbolExpr(vector_symb), PsVecBroadcast(self._lanes, PsSymbolExpr(symb)))] # Write back vectorization result - simd_writeback_local_reduction_vars += [self._type_fold(PsAssignment( + simd_writeback_local_reduction_vars += [PsAssignment( PsSymbolExpr(symb), PsVecHorizontal(self._lanes, PsSymbolExpr(symb), PsSymbolExpr(vector_symb), - reduction_info.op)))] + reduction_info.op))] # Generate vectorized loop body simd_body = self._vectorize_ast(loop.body, vc) diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py index d5f73165357efa1b996b71449489e9e0f57554cc..576cebad1ed9b842c58a521fbd4b3329af8b3a69 100644 --- a/src/pystencils/backend/transformations/select_functions.py +++ b/src/pystencils/backend/transformations/select_functions.py @@ -1,4 +1,4 @@ -from ..ast.structural import PsAssignment, PsBlock +from ..ast.structural import PsAssignment, PsBlock, PsStructuralNode from ..exceptions import MaterializationError from ..platforms import Platform from ..ast import PsAstNode @@ -31,7 +31,7 @@ class SelectFunctions: match new_rhs: case PsExpression(): return PsBlock(prepend + (PsAssignment(node.lhs, new_rhs),)) - case PsAstNode(): + case PsStructuralNode(): # special case: produces structural with atomic operation writing value back to ptr return PsBlock(prepend + (new_rhs,)) case _: