From 02be4d5e994e458e0e42d72045ec53355e7820d1 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Fri, 21 Feb 2025 19:35:48 +0100
Subject: [PATCH] Fix typecheck

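Platform.select_function previously promised
PsExpression | tuple[tuple[PsAstNode, ...], PsExpression], but the CUDA
reduction write-back returns a PsConditional, i.e. a structural node
rather than an expression. Widen the second tuple element to PsAstNode,
narrow call.function via isinstance before accessing reduction_op,
return the prepended shuffle assignments as a tuple instead of a list,
and let SelectFunctions dispatch on the type of the returned node
instead of on the platform.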
---
 src/pystencils/backend/platforms/cuda.py      | 121 +++++++++---------
 .../backend/platforms/generic_cpu.py          |  40 +++---
 src/pystencils/backend/platforms/platform.py  |   2 +-
 .../transformations/select_functions.py       |  20 +--
 src/pystencils/codegen/driver.py              |   8 +-
 5 files changed, 99 insertions(+), 92 deletions(-)

diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index 6df502c1f..291858810 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -36,8 +36,8 @@ from ..ast.expressions import (
 from ..ast.expressions import PsLt, PsAnd
 from ...types import PsSignedIntegerType, PsIeeeFloatType
 from ..literals import PsLiteral
-from ..functions import MathFunctions, CFunction, ReductionFunctions, NumericLimitsFunctions
-
+from ..functions import (MathFunctions, CFunction, ReductionFunctions,
+                         NumericLimitsFunctions, PsReductionFunction, PsMathFunction)
 
 int32 = PsSignedIntegerType(width=32, const=False)
 
@@ -209,65 +209,66 @@ class CudaPlatform(GenericGpu):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
-    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]:
-        func = call.function.func
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
+        call_func = call.function
+        assert isinstance(call_func, PsReductionFunction | PsMathFunction)
 
-        if func in ReductionFunctions:
-            match func:
-                case ReductionFunctions.WriteBackToPtr:
-                    ptr_expr, symbol_expr = call.args
-                    op = call.function.reduction_op
-                    stype = symbol_expr.dtype
-                    ptrtype = ptr_expr.dtype
-
-                    warp_size = 32   # TODO: get from platform/user config
-
-                    assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
-                    assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
-
-                    if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
-                        NotImplementedError("atomic operations are only available for float32/64 datatypes")
-
-                    # workaround for subtractions -> use additions for reducing intermediate results
-                    # similar to OpenMP reductions: local copies (negative sign) are added at the end
-                    match op:
-                        case ReductionOp.Sub:
-                            actual_op = ReductionOp.Add
-                        case _:
-                            actual_op = op
-
-                    # perform local warp reductions
-                    def gen_shuffle_instr(offset: int):
-                        full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
-                        return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
-                                      [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
-
-                    num_shuffles = math.frexp(warp_size)[1]
-                    shuffles = [PsAssignment(symbol_expr,
-                                             compound_op_to_expr(actual_op,
-                                                                 symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
-                                for i in reversed(range(1, num_shuffles))]
-
-                    # find first thread in warp
-                    ispace = self._ctx.get_iteration_space()
-                    is_valid_thread = self._get_condition_for_translation(ispace)
-                    thread_indices_per_dim = [
-                        idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
-                        for i, idx in enumerate(THREAD_IDX[:ispace.rank])
-                    ]
-                    tid: PsExpression = thread_indices_per_dim[0]
-                    for t in thread_indices_per_dim[1:]:
-                        tid = PsAdd(tid, t)
-                    first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
-                                                PsConstantExpr(PsConstant(0, SInt(32))))
-                    cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
-
-                    # use atomic operation on first thread of warp to sync
-                    call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
-                    call.args = (ptr_expr, symbol_expr)
-
-                    # assemble warp reduction
-                    return (shuffles, PsConditional(cond, PsBlock([PsStatement(call)])))
+        func = call_func.func
+
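+        # only the reduction write-back needs special treatment here;
+        # other functions fall through to the generic handling below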
+        if isinstance(call_func, PsReductionFunction) and func is ReductionFunctions.WriteBackToPtr:
+            ptr_expr, symbol_expr = call.args
+            op = call_func.reduction_op
+            stype = symbol_expr.dtype
+            ptrtype = ptr_expr.dtype
+
+            warp_size = 32  # TODO: get from platform/user config
+
+            assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
+            assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
+
+            if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
+                raise NotImplementedError("atomic operations are only available for float32/64 datatypes")
+
+            # workaround for subtractions -> use additions for reducing intermediate results
+            # similar to OpenMP reductions: local copies (negative sign) are added at the end
+            match op:
+                case ReductionOp.Sub:
+                    actual_op = ReductionOp.Add
+                case _:
+                    actual_op = op
+
+            # perform local warp reductions
+            def gen_shuffle_instr(offset: int):
+                full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
+                return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
+                              [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
+
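+            # math.frexp(32) = (0.5, 6), so five shuffle rounds are generated below
+            # with offsets 16, 8, 4, 2, 1 (butterfly reduction across the warp)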
+            num_shuffles = math.frexp(warp_size)[1]
+            shuffles = tuple(PsAssignment(symbol_expr,
+                                          compound_op_to_expr(actual_op,
+                                                              symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
+                             for i in reversed(range(1, num_shuffles)))
+
+            # find first thread in warp
+            ispace = self._ctx.get_iteration_space()
+            is_valid_thread = self._get_condition_for_translation(ispace)
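+            # linearize the thread index within the block:
+            #   tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.x * blockDim.y)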
+            thread_indices_per_dim = [
+                idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
+                for i, idx in enumerate(THREAD_IDX[:ispace.rank])
+            ]
+            tid: PsExpression = thread_indices_per_dim[0]
+            for t in thread_indices_per_dim[1:]:
+                tid = PsAdd(tid, t)
+            first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
+                                        PsConstantExpr(PsConstant(0, SInt(32))))
+            cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
+
+            # use atomic operation on first thread of warp to sync
+            call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
+            call.args = (ptr_expr, symbol_expr)
+
+            # assemble warp reduction
+            return shuffles, PsConditional(cond, PsBlock([PsStatement(call)]))
 
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index 3ffdfa22f..2f873ff29 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -4,7 +4,8 @@ from typing import Sequence
 from ..ast.expressions import PsCall, PsMemAcc, PsConstantExpr
 
 from ..ast import PsAstNode
-from ..functions import CFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions
+from ..functions import (CFunction, MathFunctions, NumericLimitsFunctions,
+                         ReductionFunctions, PsMathFunction, PsReductionFunction)
 from ..literals import PsLiteral
 from ...compound_op_mapping import compound_op_to_expr
 from ...sympyextensions import ReductionOp
@@ -59,30 +60,31 @@ class GenericCpu(Platform):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
-    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]:
-        func = call.function.func
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
+        call_func = call.function
+        assert isinstance(call_func, PsReductionFunction | PsMathFunction)
 
-        if func in ReductionFunctions:
-            match func:
-                case ReductionFunctions.WriteBackToPtr:
-                    ptr_expr, symbol_expr = call.args
-                    op = call.function.reduction_op
+        func = call_func.func
+
+        if isinstance(call_func, PsReductionFunction) and func is ReductionFunctions.WriteBackToPtr:
+            ptr_expr, symbol_expr = call.args
+            op = call_func.reduction_op
 
-                    assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
-                    assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
+            assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
+            assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
 
-                    ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
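+            # access to the value behind the reduction pointer, i.e. ptr[0]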
+            ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
 
-                    # inspired by OpenMP: local reduction variable (negative sign) is added at the end
-                    actual_op = ReductionOp.Add if op is ReductionOp.Sub else op
+            # inspired by OpenMP: local reduction variable (negative sign) is added at the end
+            actual_op = ReductionOp.Add if op is ReductionOp.Sub else op
 
-                    # TODO: can this be avoided somehow?
-                    potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr)
-                    if isinstance(potential_call, PsCall):
-                        potential_call.dtype = symbol_expr.dtype
-                        potential_call = self.select_function(potential_call)
+            # build the binop; calls (e.g. min or max) are resolved recursively below
+            potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr)
+            if isinstance(potential_call, PsCall):
+                potential_call.dtype = symbol_expr.dtype
+                return self.select_function(potential_call)
 
-                    return potential_call
+            return potential_call
 
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py
index 90fd69084..437962172 100644
--- a/src/pystencils/backend/platforms/platform.py
+++ b/src/pystencils/backend/platforms/platform.py
@@ -38,7 +38,7 @@ class Platform(ABC):
     @abstractmethod
     def select_function(
         self, call: PsCall
-    ) -> PsExpression | tuple[tuple[PsAstNode, ...], PsExpression]:
+    ) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
         """Select an implementation for the given function on the given data type.
 
         If no viable implementation exists, raise a `MaterializationError`.
diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py
index 288650698..d5f731653 100644
--- a/src/pystencils/backend/transformations/select_functions.py
+++ b/src/pystencils/backend/transformations/select_functions.py
@@ -1,6 +1,6 @@
-from ..ast.structural import PsStatement, PsAssignment, PsBlock
+from ..ast.structural import PsAssignment, PsBlock
 from ..exceptions import MaterializationError
-from ..platforms import Platform, CudaPlatform
+from ..platforms import Platform
 from ..ast import PsAstNode
 from ..ast.expressions import PsCall, PsExpression
 from ..functions import PsMathFunction, PsReductionFunction
@@ -25,13 +25,17 @@ class SelectFunctions:
                 resolved_func = self._platform.select_function(rhs)
 
                 match resolved_func:
-                    case ((prepend), expr):
-                        match self._platform:
-                            case CudaPlatform():
-                                # special case: produces statement with atomic operation writing value back to ptr
-                                return PsBlock(prepend + [PsStatement(expr)])
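+                    # the platform returned (prepend, node): emit the prepended
+                    # statements first, then the replacement node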
+                    case (prepend, new_rhs):
+                        assert isinstance(prepend, tuple)
+
+                        match new_rhs:
+                            case PsExpression():
+                                return PsBlock(prepend + (PsAssignment(node.lhs, new_rhs),))
+                            case PsAstNode():
+                                # special case: a structural node (e.g. the atomic write-back conditional)
+                                return PsBlock(prepend + (new_rhs,))
                             case _:
-                                return PsBlock(prepend + [PsAssignment(node.lhs, expr)])
+                                assert False, "Unexpected output from Platform.select_function."
                     case PsExpression():
                         return PsAssignment(node.lhs, resolved_func)
                     case _:
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index 9f04d074a..cc3411249 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -187,16 +187,16 @@ class DefaultKernelCreationDriver:
             symbol_expr = typify(PsSymbolExpr(symbol))
             ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol))
             init_val = typify(reduction_info.init_val)
-            ptr_access = PsMemAcc(ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
 
-            decl_local_copy = PsDeclaration(symbol_expr, init_val)
+            ptr_access = PsMemAcc(ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
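+            # write-back call; lowered by the platform into a binop (CPU) or a guarded atomic update (GPU)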
             write_back_ptr = PsCall(PsReductionFunction(ReductionFunctions.WriteBackToPtr, reduction_info.op),
                                     [ptr_symbol_expr, symbol_expr])
 
-            prepend_ast = [decl_local_copy]                          # declare and init local copy with neutral element
+            prepend_ast = [PsDeclaration(symbol_expr, init_val)]     # declare and init local copy with neutral element
             append_ast = [PsAssignment(ptr_access, write_back_ptr)]  # write back result to reduction target variable
 
-            kernel_ast.statements = prepend_ast + kernel_ast.statements + append_ast
+            kernel_ast.statements = prepend_ast + kernel_ast.statements
+            kernel_ast.statements += append_ast
 
         #   Target-Specific optimizations
         if self._target.is_cpu():
-- 
GitLab