From 90837d04eaad5b3ad049c11fa5af48cf0942e812 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Thu, 20 Mar 2025 15:36:01 +0100
Subject: [PATCH] Merge handling for GPU reductions into generic_gpu.py for the
 time being

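Reductions are lowered directly in the generic GPU platform for now:
partial results are first combined within each warp using XOR shuffles,
after which the first lane of every warp writes its result back with an
atomic operation. Roughly, the emitted device code corresponds to the
following sketch (assuming warp_size = 32 and an add reduction; `x`,
`ptr` and `tid` stand in for the local result, the destination pointer
and the linearized thread index):

    for (int offset = 16; offset > 0; offset /= 2)
        x = x + __shfl_xor_sync(0xffffffffu, x, offset);
    if (tid % 32 == 0)
        atomicAdd(ptr, x);

Subtraction reductions are rewritten to additions of the negated local
copies, analogous to OpenMP reductions.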
---
 .../backend/platforms/generic_gpu.py          | 147 +++++++++++++++---
 src/pystencils/codegen/driver.py              |   5 +
 2 files changed, 130 insertions(+), 22 deletions(-)

diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index 11425d923..f16c28e8c 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -1,7 +1,16 @@
 from __future__ import annotations
-from abc import ABC, abstractmethod
 
-from ...types import constify, deconstify
+import math
+import operator
+from abc import ABC, abstractmethod
+from functools import reduce
+
+from ..ast import PsAstNode
+from ..constants import PsConstant
+from ...compound_op_mapping import compound_op_to_expr
+from ...sympyextensions.reduction import ReductionOp
+from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType
+from ...types.quick import UInt, SInt
 from ..exceptions import MaterializationError
 from .platform import Platform
 
@@ -15,7 +24,7 @@ from ..kernelcreation import (
 )
 
 from ..kernelcreation.context import KernelCreationContext
-from ..ast.structural import PsBlock, PsConditional, PsDeclaration
+from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment
 from ..ast.expressions import (
     PsExpression,
     PsLiteralExpr,
@@ -23,12 +32,17 @@ from ..ast.expressions import (
     PsCall,
     PsLookup,
     PsBufferAcc,
+    PsSymbolExpr,
+    PsConstantExpr,
+    PsAdd,
+    PsRem,
+    PsEq
 )
 from ..ast.expressions import PsLt, PsAnd
 from ...types import PsSignedIntegerType, PsIeeeFloatType
 from ..literals import PsLiteral
-from ..functions import PsMathFunction, MathFunctions, CFunction
-
+from ..functions import (
+    MathFunctions,
+    CFunction,
+    ReductionFunctions,
+    NumericLimitsFunctions,
+    PsReductionFunction,
+    PsMathFunction,
+)
 
 int32 = PsSignedIntegerType(width=32, const=False)
 
@@ -174,10 +188,15 @@ class GenericGpu(Platform):
     def __init__(
         self,
         ctx: KernelCreationContext,
+        assume_warp_aligned_block_size: bool,
+        warp_size: int | None,
         thread_mapping: ThreadMapping | None = None,
     ) -> None:
         super().__init__(ctx)
 
+        self._assume_warp_aligned_block_size = assume_warp_aligned_block_size
+        self._warp_size = warp_size
+
         self._thread_mapping = (
             thread_mapping if thread_mapping is not None else Linear3DMapping()
         )
@@ -194,14 +213,107 @@ class GenericGpu(Platform):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
-    def select_function(self, call: PsCall) -> PsExpression:
-        assert isinstance(call.function, PsMathFunction)
+    @staticmethod
+    def _get_condition_for_translation(ispace: IterationSpace) -> PsExpression:
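+        """Return the condition masking threads outside the iteration space."""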
+        if isinstance(ispace, FullIterationSpace):
+            conds = []
+
+            dimensions = ispace.dimensions_in_loop_order()
+
+            for dim in dimensions:
+                ctr_expr = PsExpression.make(dim.counter)
+                conds.append(PsLt(ctr_expr, dim.stop))
+
+            condition: PsExpression = conds[0]
+            for cond in conds[1:]:
+                condition = PsAnd(condition, cond)
+
+            return condition
+        elif isinstance(ispace, SparseIterationSpace):
+            sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
+            stop = PsExpression.make(ispace.index_list.shape[0])
+
+            return PsLt(sparse_ctr_expr.clone(), stop)
+        else:
+            raise MaterializationError(f"Unknown type of iteration space: {ispace}")
+
+    def select_function(
+        self, call: PsCall
+    ) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
+        call_func = call.function
+        assert isinstance(call_func, PsReductionFunction | PsMathFunction)
+
+        func = call_func.func
+
+        if (
+            isinstance(call_func, PsReductionFunction)
+            and func is ReductionFunctions.WriteBackToPtr
+        ):
+            ptr_expr, symbol_expr = call.args
+            op = call_func.reduction_op
+            stype = symbol_expr.dtype
+            ptrtype = ptr_expr.dtype
+
+            assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
+            assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
+
+            if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
+                raise NotImplementedError(
+                    "atomic operations are only available for float32/64 datatypes"
+                )
+
+            # workaround for subtractions -> use additions for reducing intermediate results;
+            # as with OpenMP reductions, the negated local copies are added back at the end
+            match op:
+                case ReductionOp.Sub:
+                    actual_op = ReductionOp.Add
+                case _:
+                    actual_op = op
+
+            # perform local warp reductions
+            def gen_shuffle_instr(offset: int):
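+                # __shfl_xor_sync reads the partial result of lane (lane_id XOR offset)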
+                full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
+                return PsCall(
+                    CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
+                    [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))],
+                )
+
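+            # for power-of-two warp sizes, frexp(warp_size)[1] == log2(warp_size) + 1,
+            # so the offsets below run warp_size/2, ..., 2, 1 (one butterfly stage each)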
+            assert self._warp_size is not None
+            num_shuffles = math.frexp(self._warp_size)[1]
+            shuffles = tuple(
+                PsAssignment(
+                    symbol_expr,
+                    compound_op_to_expr(
+                        actual_op, symbol_expr, gen_shuffle_instr(pow(2, i - 1))
+                    ),
+                )
+                for i in reversed(range(1, num_shuffles))
+            )
+
+            # find first thread in warp
+            ispace = self._ctx.get_iteration_space()
+            is_valid_thread = self._get_condition_for_translation(ispace)
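+            # linearize the thread index within the block:
+            # tid = tid.x + tid.y * blockDim.x + tid.z * blockDim.x * blockDim.y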
+            thread_indices_per_dim = [
+                idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
+                for i, idx in enumerate(THREAD_IDX[:ispace.rank])
+            ]
+            tid: PsExpression = thread_indices_per_dim[0]
+            for t in thread_indices_per_dim[1:]:
+                tid = PsAdd(tid, t)
+            first_thread_in_warp = PsEq(
+                PsRem(tid, PsConstantExpr(PsConstant(self._warp_size, SInt(32)))),
+                PsConstantExpr(PsConstant(0, SInt(32))),
+            )
+            cond = (
+                PsAnd(is_valid_thread, first_thread_in_warp)
+                if is_valid_thread
+                else first_thread_in_warp
+            )
+
+            # the first lane of each warp atomically merges its partial result
+            call.function = CFunction(
+                f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void")
+            )
+            call.args = (ptr_expr, symbol_expr)
+
+            # assemble warp reduction: shuffles, then the guarded atomic write-back
+            return shuffles, PsConditional(cond, PsBlock([PsStatement(call)]))
 
-        func = call.function.func
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
 
-        if isinstance(dtype, PsIeeeFloatType):
+        if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions:
+            assert isinstance(dtype, PsIeeeFloatType)
+
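+            # the numeric limits of IEEE float types materialize as +/- infinity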
+            match func:
+                case NumericLimitsFunctions.Min:
+                    define = "NEG_INFINITY"
+                case NumericLimitsFunctions.Max:
+                    define = "POS_INFINITY"
+                case _:
+                    raise MaterializationError(f"Cannot materialize call to function {func}")
+
+            return PsLiteralExpr(PsLiteral(define, dtype))
+
+        if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions:
             match func:
                 case (
                     MathFunctions.Exp
@@ -262,7 +374,6 @@ class GenericGpu(Platform):
         ctr_mapping = self._thread_mapping(ispace)
 
         indexing_decls = []
-        conds = []
 
         dimensions = ispace.dimensions_in_loop_order()
 
@@ -276,14 +387,9 @@ class GenericGpu(Platform):
             indexing_decls.append(
                 self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter]))
             )
-            conds.append(PsLt(ctr_expr, dim.stop))
-
-        condition: PsExpression = conds[0]
-        for cond in conds[1:]:
-            condition = PsAnd(condition, cond)
-        ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
 
-        return ast
+        cond = self._get_condition_for_translation(ispace)
+        return PsBlock(indexing_decls + [PsConditional(cond, body)])
 
     def _prepend_sparse_translation(
         self, body: PsBlock, ispace: SparseIterationSpace
@@ -313,8 +419,5 @@ class GenericGpu(Platform):
         ]
         body.statements = mappings + body.statements
 
-        stop = PsExpression.make(ispace.index_list.shape[0])
-        condition = PsLt(sparse_ctr_expr.clone(), stop)
-        ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
-
-        return ast
+        cond = self._get_condition_for_translation(ispace)
+        return PsBlock([sparse_idx_decl, PsConditional(cond, body)])
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index c2bee0ad2..3962c316b 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -475,6 +475,9 @@ class DefaultKernelCreationDriver:
                 else None
             )
 
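+            # forwarded to the GPU platform; the warp size drives warp-level reductions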
+            assume_warp_aligned_block_size: bool = self._cfg.gpu.get_option(
+                "assume_warp_aligned_block_size"
+            )
+            warp_size: int | None = self._cfg.gpu.get_option("warp_size")
+
             GpuPlatform: type
             match self._target:
                 case Target.CUDA:
@@ -486,6 +489,8 @@ class DefaultKernelCreationDriver:
 
             return GpuPlatform(
                 self._ctx,
+                assume_warp_aligned_block_size,
+                warp_size,
                 thread_mapping=thread_mapping,
             )
 
-- 
GitLab