From 9ec813bdd1c8dfd2a1f32bb300d6f9e7d172542f Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Thu, 20 Mar 2025 16:07:01 +0100
Subject: [PATCH] Employ optimized warp-level reduction when warp size is known and blocks are warp-aligned

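When the warp size is known and the block size can be assumed to be
warp-aligned, first reduce within each warp via __shfl_xor_sync and let
only the first thread of each warp issue the atomic operation.
Otherwise, fall back to issuing the atomic operation on every valid
thread.

A minimal sketch of the shuffle-offset sequence generated by the
optimized path, assuming a power-of-two warp size of 32:

    import math

    warp_size = 32  # assumption: power of two
    num_shuffles = math.frexp(warp_size)[1]  # frexp(32) == (0.5, 6)
    offsets = [2 ** (i - 1) for i in reversed(range(1, num_shuffles))]
    # offsets == [16, 8, 4, 2, 1]: each butterfly step halves the number
    # of lanes still holding distinct partial results
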
---
 .../backend/platforms/generic_gpu.py          | 67 ++++++++++++-------
 1 file changed, 42 insertions(+), 25 deletions(-)

diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index f16c28e8c..d3e8de42d 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -264,33 +264,50 @@ class GenericGpu(Platform):
                 case _:
                     actual_op = op
 
-            # perform local warp reductions
-            def gen_shuffle_instr(offset: int):
-                full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
-                return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
-                              [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
-
-            num_shuffles = math.frexp(self._warp_size)[1]
-            shuffles = tuple(PsAssignment(symbol_expr,
-                                          compound_op_to_expr(actual_op,
-                                                              symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
-                             for i in reversed(range(1, num_shuffles)))
-
-            # find first thread in warp
+            # check whether the thread is valid for performing the reduction
             ispace = self._ctx.get_iteration_space()
             is_valid_thread = self._get_condition_for_translation(ispace)
-            thread_indices_per_dim = [
-                idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
-                for i, idx in enumerate(THREAD_IDX[:ispace.rank])
-            ]
-            tid: PsExpression = thread_indices_per_dim[0]
-            for t in thread_indices_per_dim[1:]:
-                tid = PsAdd(tid, t)
-            first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(self._warp_size, SInt(32)))),
-                                        PsConstantExpr(PsConstant(0, SInt(32))))
-            cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
-
-            # use atomic operation on first thread of warp to sync
+
+            cond: PsExpression
+            shuffles: tuple[PsAssignment, ...]
+            if self._warp_size and self._assume_warp_aligned_block_size:
+                # perform local warp reductions
+                def gen_shuffle_instr(offset: int):
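+                    # __shfl_xor_sync fetches symbol_expr from the lane whose ID
+                    # equals the calling lane's ID XOR offset (butterfly exchange)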
+                    full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
+                    return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
+                                  [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
+
+                # set up shuffle instructions for warp-level reduction
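+                # frexp(warp_size)[1] == log2(warp_size) + 1 for power-of-two
+                # warp sizes, so the generated offsets are warp_size/2, ..., 2, 1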
+                num_shuffles = math.frexp(self._warp_size)[1]
+                shuffles = tuple(PsAssignment(symbol_expr,
+                                              compound_op_to_expr(actual_op,
+                                                                  symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
+                                 for i in reversed(range(1, num_shuffles)))
+
+                # find first thread in warp
+                thread_indices_per_dim = [
+                    idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
+                    for i, idx in enumerate(THREAD_IDX[:ispace.rank])
+                ]
+                tid: PsExpression = thread_indices_per_dim[0]
+                for t in thread_indices_per_dim[1:]:
+                    tid = PsAdd(tid, t)
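+                # tid is the linearized thread index within the block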
+                first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(self._warp_size, SInt(32)))),
+                                            PsConstantExpr(PsConstant(0, SInt(32))))
+
+                # restrict the atomic operation to the first valid thread in each warp
+                cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
+            else:
+                # no optimization: every valid thread executes the atomic operation directly
+                shuffles = ()
+                cond = is_valid_thread
+
+            # combine intermediate results via atomic operation (e.g. atomicAdd, atomicMax)
             call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
             call.args = (ptr_expr, symbol_expr)
 
-- 
GitLab