From a8479afadc2a95e2ecda861ce7fd186375477347 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Fri, 21 Feb 2025 18:11:51 +0100
Subject: [PATCH] Refactor reductionfunction mechanism

---
 src/pystencils/backend/functions.py           |   1 -
 .../backend/kernelcreation/typification.py    |   2 +-
 src/pystencils/backend/platforms/cuda.py      | 142 ++++++++----------
 .../backend/platforms/generic_cpu.py          |  49 +++---
 src/pystencils/backend/platforms/platform.py  |  12 +-
 src/pystencils/backend/platforms/sycl.py      |   9 +-
 .../transformations/select_functions.py       |  36 ++++-
 src/pystencils/codegen/driver.py              |  16 +-
 8 files changed, 119 insertions(+), 148 deletions(-)

diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py
index d28ef5f44..4e38de5e9 100644
--- a/src/pystencils/backend/functions.py
+++ b/src/pystencils/backend/functions.py
@@ -142,7 +142,6 @@ class ReductionFunctions(Enum):
     Each platform has to materialize these functions to a concrete implementation.
     """
 
-    InitLocalCopy = ("InitLocalCopy", 2)
     WriteBackToPtr = ("WriteBackToPtr", 2)
 
     def __init__(self, func_name, num_args):
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 544746ef6..284e80b9d 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -617,7 +617,7 @@ class Typifier:
 
             case PsCall(function, args):
                 match function:
-                    case PsMathFunction() | PsReductionFunction():
+                    case PsMathFunction():
                         for arg in args:
                             self.visit_expr(arg, tc)
                         tc.infer_dtype(expr)
diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index 8936bf73f..1f6506c8f 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -212,10 +212,66 @@ class CudaPlatform(GenericGpu):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
-    def select_function(self, call: PsCall) -> PsExpression:
-        assert isinstance(call.function, PsMathFunction)
-
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]:
         func = call.function.func
+
+        if func in ReductionFunctions:
+            match func:
+                case ReductionFunctions.WriteBackToPtr:
+                    ptr_expr, symbol_expr = call.args
+                    op = call.function.reduction_op
+                    stype = symbol_expr.dtype
+                    ptrtype = ptr_expr.dtype
+
+                    warp_size = 32   # TODO: get from platform/user config
+
+                    assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
+                    assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
+
+                    if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
+                        NotImplementedError("atomic operations are only available for float32/64 datatypes")
+
+                    # workaround for subtractions -> use additions for reducing intermediate results
+                    # similar to OpenMP reductions: local copies (negative sign) are added at the end
+                    match op:
+                        case ReductionOp.Sub:
+                            actual_op = ReductionOp.Add
+                        case _:
+                            actual_op = op
+
+                    # perform local warp reductions
+                    def gen_shuffle_instr(offset: int):
+                        full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
+                        return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
+                                      [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
+
+                    num_shuffles = math.frexp(warp_size)[1]
+                    shuffles = [PsAssignment(symbol_expr,
+                                             compound_op_to_expr(actual_op,
+                                                                 symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
+                                for i in reversed(range(1, num_shuffles))]
+
+                    # find first thread in warp
+                    ispace = self._ctx.get_iteration_space()
+                    is_valid_thread = self._get_condition_for_translation(ispace)
+                    thread_indices_per_dim = [
+                        idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
+                        for i, idx in enumerate(THREAD_IDX[:ispace.rank])
+                    ]
+                    tid: PsExpression = thread_indices_per_dim[0]
+                    for t in thread_indices_per_dim[1:]:
+                        tid = PsAdd(tid, t)
+                    first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
+                                                PsConstantExpr(PsConstant(0, SInt(32))))
+                    cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
+
+                    # use atomic operation on first thread of warp to sync
+                    call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
+                    call.args = (ptr_expr, symbol_expr)
+
+                    # assemble warp reduction
+                    return (shuffles, PsConditional(cond, PsBlock([PsStatement(call)])))
+
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
 
@@ -232,7 +288,7 @@ class CudaPlatform(GenericGpu):
 
             return PsLiteralExpr(PsLiteral(define, dtype))
 
-        if isinstance(dtype, PsIeeeFloatType):
+        if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions:
             match func:
                 case (
                     MathFunctions.Exp
@@ -285,84 +341,6 @@ class CudaPlatform(GenericGpu):
             f"No implementation available for function {func} on data type {dtype}"
         )
 
-    def unfold_function(
-        self, call: PsCall
-    ) -> PsAstNode:
-        assert isinstance(call.function, PsReductionFunction)
-
-        func = call.function.func
-
-        match func:
-            case ReductionFunctions.InitLocalCopy:
-                symbol_expr, init_val = call.args
-                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(init_val, PsExpression)
-
-                return PsDeclaration(symbol_expr, init_val)
-            case ReductionFunctions.WriteBackToPtr:
-                ptr_expr, symbol_expr = call.args
-                op = call.function.reduction_op
-                stype = symbol_expr.dtype
-                ptrtype = ptr_expr.dtype
-
-                warp_size = 32   # TODO: get from platform/user config
-
-                assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
-                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
-
-                if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
-                    NotImplementedError("atomic operations are only available for float32/64 datatypes")
-
-                # set up mask symbol for active threads in warp
-                #mask = PsSymbol("__shfl_mask", UInt(32))
-                #self._ctx.add_symbol(mask)
-                full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
-
-                # workaround for subtractions -> use additions for reducing intermediate results
-                # similar to OpenMP reductions: local copies (negative sign) are added at the end
-                match op:
-                    case ReductionOp.Sub:
-                        actual_op = ReductionOp.Add
-                    case _:
-                        actual_op = op
-
-                # perform local warp reductions
-                def gen_shuffle_instr(offset: int):
-                    return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
-                                  [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
-
-                num_shuffles = math.frexp(warp_size)[1]
-                shuffles = [PsAssignment(symbol_expr,
-                                         compound_op_to_expr(actual_op, symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
-                            for i in reversed(range(1, num_shuffles))]
-
-                # find first thread in warp
-                ispace = self._ctx.get_iteration_space()  # TODO: receive as argument in unfold_function?
-                is_valid_thread = self._get_condition_for_translation(ispace)
-                thread_indices_per_dim = [
-                    idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
-                    for i, idx in enumerate(THREAD_IDX[:ispace.rank])
-                ]
-                tid: PsExpression = thread_indices_per_dim[0]
-                for t in thread_indices_per_dim[1:]:
-                    tid = PsAdd(tid, t)
-                first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
-                                            PsConstantExpr(PsConstant(0, SInt(32))))
-                cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
-
-                #ballot_instr = PsCall(CFunction("__ballot_sync", [UInt(32), SInt(32)], SInt(32)),
-                #                                 [full_mask, is_valid_thread])
-                #decl_mask = PsDeclaration(full_mask)
-
-                # use atomic operation on first thread of warp to sync
-                call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
-                call.args = (ptr_expr, symbol_expr)
-
-                # assemble warp reduction
-                return PsBlock(
-                    #[decl_mask]
-                    shuffles
-                    + [PsConditional(cond, PsBlock([PsStatement(call)]))])
-
     #   Internals
 
     # TODO: SYCL platform has very similar code for fetching conditionals -> move to GenericGPU?
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index 1e7468e33..24692b25c 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -4,8 +4,7 @@ from typing import Sequence
 from ..ast.expressions import PsCall, PsMemAcc, PsConstantExpr
 
 from ..ast import PsAstNode
-from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions, \
-    PsReductionFunction
+from ..functions import CFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions
 from ..literals import PsLiteral
 from ...compound_op_mapping import compound_op_to_expr
 from ...sympyextensions import ReductionOp
@@ -60,43 +59,31 @@ class GenericCpu(Platform):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
-    def unfold_function(
-        self, call: PsCall
-    ) -> PsAstNode:
-        assert isinstance(call.function, PsReductionFunction)
-
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]:
         func = call.function.func
 
-        match func:
-            case ReductionFunctions.InitLocalCopy:
-                symbol_expr, init_val = call.args
-                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(init_val, PsExpression)
-
-                return PsDeclaration(symbol_expr, init_val)
-            case ReductionFunctions.WriteBackToPtr:
-                ptr_expr, symbol_expr = call.args
-                op = call.function.reduction_op
-
-                assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
-                assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
+        if func in ReductionFunctions:
+            match func:
+                case ReductionFunctions.WriteBackToPtr:
+                    ptr_expr, symbol_expr = call.args
+                    op = call.function.reduction_op
 
-                ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
+                    assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
+                    assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
 
-                # inspired by OpenMP: local reduction variable (negative sign) is added at the end
-                actual_op = ReductionOp.Add if op is ReductionOp.Sub else op
+                    ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
 
-                # TODO: can this be avoided somehow?
-                potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr)
-                if isinstance(potential_call, PsCall):
-                    potential_call.dtype = symbol_expr.dtype
-                    potential_call = self.select_function(potential_call)
+                    # inspired by OpenMP: local reduction variable (negative sign) is added at the end
+                    actual_op = ReductionOp.Add if op is ReductionOp.Sub else op
 
-                return PsAssignment(ptr_access, potential_call)
+                    # TODO: can this be avoided somehow?
+                    potential_call = compound_op_to_expr(actual_op, ptr_access, symbol_expr)
+                    if isinstance(potential_call, PsCall):
+                        potential_call.dtype = symbol_expr.dtype
+                        potential_call = self.select_function(potential_call)
 
-    def select_function(self, call: PsCall) -> PsExpression:
-        assert isinstance(call.function, PsMathFunction)
+                    return potential_call
 
-        func = call.function.func
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
 
diff --git a/src/pystencils/backend/platforms/platform.py b/src/pystencils/backend/platforms/platform.py
index e195d59bc..90fd69084 100644
--- a/src/pystencils/backend/platforms/platform.py
+++ b/src/pystencils/backend/platforms/platform.py
@@ -38,19 +38,9 @@ class Platform(ABC):
     @abstractmethod
     def select_function(
         self, call: PsCall
-    ) -> PsExpression:
+    ) -> PsExpression | tuple[tuple[PsAstNode, ...], PsExpression]:
         """Select an implementation for the given function on the given data type.
 
         If no viable implementation exists, raise a `MaterializationError`.
         """
         pass
-
-    @abstractmethod
-    def unfold_function(
-        self, call: PsCall
-    ) -> PsAstNode:
-        """Unfolds an implementation for the given function on the given data type.
-
-        If no viable implementation exists, raise a `MaterializationError`.
-        """
-        pass
diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py
index 9b077fd2b..eae2b7598 100644
--- a/src/pystencils/backend/platforms/sycl.py
+++ b/src/pystencils/backend/platforms/sycl.py
@@ -57,7 +57,7 @@ class SyclPlatform(GenericGpu):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
-    def select_function(self, call: PsCall) -> PsExpression:
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode], PsExpression]:
         assert isinstance(call.function, PsMathFunction)
 
         func = call.function.func
@@ -108,13 +108,6 @@ class SyclPlatform(GenericGpu):
             f"No implementation available for function {func} on data type {dtype}"
         )
 
-    def unfold_function(
-        self, call: PsCall
-    ) -> PsAstNode:
-        raise MaterializationError(
-            f"No implementation available for function {call.function.name}"
-        )
-
     def _prepend_dense_translation(
         self, body: PsBlock, ispace: FullIterationSpace
     ) -> PsBlock:
diff --git a/src/pystencils/backend/transformations/select_functions.py b/src/pystencils/backend/transformations/select_functions.py
index 0045de87b..288650698 100644
--- a/src/pystencils/backend/transformations/select_functions.py
+++ b/src/pystencils/backend/transformations/select_functions.py
@@ -1,6 +1,8 @@
-from ..platforms import Platform
+from ..ast.structural import PsStatement, PsAssignment, PsBlock
+from ..exceptions import MaterializationError
+from ..platforms import Platform, CudaPlatform
 from ..ast import PsAstNode
-from ..ast.expressions import PsCall
+from ..ast.expressions import PsCall, PsExpression
 from ..functions import PsMathFunction, PsReductionFunction
 
 
@@ -17,9 +19,31 @@ class SelectFunctions:
     def visit(self, node: PsAstNode) -> PsAstNode:
         node.children = [self.visit(c) for c in node.children]
 
-        if isinstance(node, PsCall) and isinstance(node.function, PsMathFunction):
-            return self._platform.select_function(node)
-        elif isinstance(node, PsCall) and isinstance(node.function, PsReductionFunction):
-            return self._platform.unfold_function(node)
+        if isinstance(node, PsAssignment):
+            rhs = node.rhs
+            if isinstance(rhs, PsCall) and isinstance(rhs.function, PsReductionFunction):
+                resolved_func = self._platform.select_function(rhs)
+
+                match resolved_func:
+                    case ((prepend), expr):
+                        match self._platform:
+                            case CudaPlatform():
+                                # special case: produces statement with atomic operation writing value back to ptr
+                                return PsBlock(prepend + [PsStatement(expr)])
+                            case _:
+                                return PsBlock(prepend + [PsAssignment(node.lhs, expr)])
+                    case PsExpression():
+                        return PsAssignment(node.lhs, resolved_func)
+                    case _:
+                        raise MaterializationError(
+                            f"Wrong return type for resolved function {rhs.function.name} in SelectFunctions."
+                        )
+            else:
+                return node
+        elif isinstance(node, PsCall) and isinstance(node.function, PsMathFunction):
+            resolved_func = self._platform.select_function(node)
+            assert isinstance(resolved_func, PsExpression)
+
+            return resolved_func
         else:
             return node
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index 96e9b94ed..9f04d074a 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -16,7 +16,7 @@ from .kernel import Kernel, GpuKernel
 from .properties import PsSymbolProperty, FieldBasePtr
 from .parameters import Parameter
 from ..backend.functions import PsReductionFunction, ReductionFunctions
-from ..backend.ast.expressions import PsSymbolExpr, PsCall
+from ..backend.ast.expressions import PsSymbolExpr, PsCall, PsMemAcc, PsConstantExpr
 from .gpu_indexing import GpuIndexing, GpuLaunchConfiguration
 
 from ..field import Field
@@ -24,7 +24,7 @@ from ..types import PsIntegerType, PsScalarType
 
 from ..backend.memory import PsSymbol
 from ..backend.ast import PsAstNode
-from ..backend.ast.structural import PsBlock, PsLoop
+from ..backend.ast.structural import PsBlock, PsLoop, PsDeclaration, PsAssignment
 from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers
 from ..backend.kernelcreation import (
     KernelCreationContext,
@@ -187,16 +187,16 @@ class DefaultKernelCreationDriver:
             symbol_expr = typify(PsSymbolExpr(symbol))
             ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol))
             init_val = typify(reduction_info.init_val)
+            ptr_access = PsMemAcc(ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
 
-            init_local_copy = PsCall(PsReductionFunction(ReductionFunctions.InitLocalCopy, reduction_info.op),
-                                     [symbol_expr, init_val])
+            decl_local_copy = PsDeclaration(symbol_expr, init_val)
             write_back_ptr = PsCall(PsReductionFunction(ReductionFunctions.WriteBackToPtr, reduction_info.op),
                                     [ptr_symbol_expr, symbol_expr])
 
-            # Init local reduction variable copy
-            kernel_ast.statements = [init_local_copy] + kernel_ast.statements
-            # Write back result to reduction target variable
-            kernel_ast.statements += [write_back_ptr]
+            prepend_ast = [decl_local_copy]                          # declare and init local copy with neutral element
+            append_ast = [PsAssignment(ptr_access, write_back_ptr)]  # write back result to reduction target variable
+
+            kernel_ast.statements = prepend_ast + kernel_ast.statements + append_ast
 
         #   Target-Specific optimizations
         if self._target.is_cpu():
-- 
GitLab