From d7e6890cff8f327ce285a3360821a4341b2acc46 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Thu, 20 Mar 2025 16:23:54 +0100
Subject: [PATCH] Fix typecheck

---
 src/pystencils/backend/platforms/generic_cpu.py    |  4 ++--
 src/pystencils/backend/platforms/generic_gpu.py    |  4 ++--
 src/pystencils/backend/platforms/platform.py       |  4 ++--
 src/pystencils/backend/platforms/sycl.py           |  4 ++--
 .../backend/transformations/loop_vectorizer.py     | 14 +++++++-------
 .../backend/transformations/select_functions.py    |  4 ++--
 6 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index 2f873ff29..43b048184 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -22,7 +22,7 @@ from ..kernelcreation.iteration_space import (
 )
 
 from ..constants import PsConstant
-from ..ast.structural import PsDeclaration, PsLoop, PsBlock
+from ..ast.structural import PsDeclaration, PsLoop, PsBlock, PsStructuralNode
 from ..ast.expressions import (
     PsSymbolExpr,
     PsExpression,
@@ -60,7 +60,7 @@ class GenericCpu(Platform):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
-    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]:
         call_func = call.function
         assert isinstance(call_func, PsReductionFunction | PsMathFunction)
 
diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index d3e8de42d..9b21457be 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -24,7 +24,7 @@ from ..kernelcreation import (
 )
 
 from ..kernelcreation.context import KernelCreationContext
-from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment
+from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment, PsStructuralNode
 from ..ast.expressions import (
     PsExpression,
     PsLiteralExpr,
@@ -238,7 +238,7 @@ class GenericGpu(Platform):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
-    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]:
         call_func = call.function
         assert isinstance(call_func, PsReductionFunction | PsMathFunction)
 
diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py
index 437962172..4f738dd5d 100644
--- a/src/pystencils/backend/platforms/platform.py
+++ b/src/pystencils/backend/platforms/platform.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 
 from ..ast import PsAstNode
-from ..ast.structural import PsBlock
+from ..ast.structural import PsBlock, PsStructuralNode
 from ..ast.expressions import PsCall, PsExpression
 
 from ..kernelcreation.context import KernelCreationContext
@@ -38,7 +38,7 @@ class Platform(ABC):
     @abstractmethod
     def select_function(
         self, call: PsCall
-    ) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
+    ) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]:
         """Select an implementation for the given function on the given data type.
 
         If no viable implementation exists, raise a `MaterializationError`.
diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py
index 7d7b8d1a7..78af01b2f 100644
--- a/src/pystencils/backend/platforms/sycl.py
+++ b/src/pystencils/backend/platforms/sycl.py
@@ -7,7 +7,7 @@ from ..kernelcreation.iteration_space import (
     FullIterationSpace,
     SparseIterationSpace,
 )
-from ..ast.structural import PsDeclaration, PsBlock, PsConditional
+from ..ast.structural import PsDeclaration, PsBlock, PsConditional, PsStructuralNode
 from ..ast.expressions import (
     PsExpression,
     PsSymbolExpr,
@@ -56,7 +56,7 @@ class SyclPlatform(Platform):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
-    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]:
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]:
         assert isinstance(call.function, PsMathFunction)
 
         func = call.function.func
diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py
index b78114553..a96c6af4b 100644
--- a/src/pystencils/backend/transformations/loop_vectorizer.py
+++ b/src/pystencils/backend/transformations/loop_vectorizer.py
@@ -7,7 +7,7 @@ from ...types import PsVectorType, PsScalarType
 from ..kernelcreation import KernelCreationContext
 from ..constants import PsConstant
 from ..ast import PsAstNode
-from ..ast.structural import PsLoop, PsBlock, PsDeclaration, PsAssignment
+from ..ast.structural import PsLoop, PsBlock, PsDeclaration, PsAssignment, PsStructuralNode
 from ..ast.expressions import PsExpression, PsTernary, PsGt, PsSymbolExpr
 from ..ast.vector import PsVecBroadcast, PsVecHorizontal
 from ..ast.analysis import collect_undefined_symbols
@@ -135,20 +135,20 @@ class LoopVectorizer:
         vc = VectorizationContext(self._ctx, self._lanes, axis)
 
         #   Prepare reductions
-        simd_init_local_reduction_vars = []
-        simd_writeback_local_reduction_vars = []
+        simd_init_local_reduction_vars: list[PsStructuralNode] = []
+        simd_writeback_local_reduction_vars: list[PsStructuralNode] = []
         for symb, reduction_info in self._ctx.symbols_reduction_info.items():
             # Vectorize symbol for local copy
             vector_symb = vc.vectorize_symbol(symb)
 
             # Declare and init vector
-            simd_init_local_reduction_vars += [self._type_fold(PsDeclaration(
-                PsSymbolExpr(vector_symb), PsVecBroadcast(self._lanes, PsSymbolExpr(symb))))]
+            simd_init_local_reduction_vars += [PsDeclaration(
+                PsSymbolExpr(vector_symb), PsVecBroadcast(self._lanes, PsSymbolExpr(symb)))]
 
             # Write back vectorization result
-            simd_writeback_local_reduction_vars += [self._type_fold(PsAssignment(
+            simd_writeback_local_reduction_vars += [PsAssignment(
                 PsSymbolExpr(symb), PsVecHorizontal(self._lanes, PsSymbolExpr(symb), PsSymbolExpr(vector_symb),
-                                                    reduction_info.op)))]
+                                                    reduction_info.op))]
 
         #   Generate vectorized loop body
         simd_body = self._vectorize_ast(loop.body, vc)
diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py
index d5f731653..576cebad1 100644
--- a/src/pystencils/backend/transformations/select_functions.py
+++ b/src/pystencils/backend/transformations/select_functions.py
@@ -1,4 +1,4 @@
-from ..ast.structural import PsAssignment, PsBlock
+from ..ast.structural import PsAssignment, PsBlock, PsStructuralNode
 from ..exceptions import MaterializationError
 from ..platforms import Platform
 from ..ast import PsAstNode
@@ -31,7 +31,7 @@ class SelectFunctions:
                         match new_rhs:
                             case PsExpression():
                                 return PsBlock(prepend + (PsAssignment(node.lhs, new_rhs),))
-                            case PsAstNode():
+                            case PsStructuralNode():
                                 # special case: produces structural with atomic operation writing value back to ptr
                                 return PsBlock(prepend + (new_rhs,))
                             case _:
-- 
GitLab