From bcd83842f3d331fa039a96e3eab68c34b3beb6f3 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Fri, 21 Feb 2025 14:56:46 +0100
Subject: [PATCH] Use full mask for CUDA reductions

---
 src/pystencils/backend/platforms/cuda.py | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index 873961cc7..8936bf73f 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -313,8 +313,9 @@ class CudaPlatform(GenericGpu):
                     NotImplementedError("atomic operations are only available for float32/64 datatypes")
 
                 # set up mask symbol for active threads in warp
-                mask = PsSymbol("__shfl_mask", UInt(32))
-                self._ctx.add_symbol(mask)
+                #mask = PsSymbol("__shfl_mask", UInt(32))
+                #self._ctx.add_symbol(mask)
+                full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
 
                 # workaround for subtractions -> use additions for reducing intermediate results
                 # similar to OpenMP reductions: local copies (negative sign) are added at the end
@@ -327,7 +328,7 @@ class CudaPlatform(GenericGpu):
                 # perform local warp reductions
                 def gen_shuffle_instr(offset: int):
                     return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
-                                  [PsSymbolExpr(mask), symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
+                                  [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
 
                 num_shuffles = math.frexp(warp_size)[1]
                 shuffles = [PsAssignment(symbol_expr,
@@ -348,21 +349,19 @@ class CudaPlatform(GenericGpu):
                                             PsConstantExpr(PsConstant(0, SInt(32))))
                 cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
 
-                full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
-                ballot_instr = PsCall(CFunction("__ballot_sync", [UInt(32), SInt(32)], SInt(32)),
-                                                 [full_mask, is_valid_thread])
-                decl_mask = PsDeclaration(PsSymbolExpr(mask), ballot_instr if is_valid_thread else full_mask)
+                #ballot_instr = PsCall(CFunction("__ballot_sync", [UInt(32), SInt(32)], SInt(32)),
+                #                                 [full_mask, is_valid_thread])
+                #decl_mask = PsDeclaration(full_mask)
 
                 # use atomic operation on first thread of warp to sync
                 call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
                 call.args = (ptr_expr, symbol_expr)
 
                 # assemble warp reduction
-                return PsConditional(is_valid_thread if is_valid_thread else PsConstantExpr(PsLiteral("true", PsBoolType)),
-                    PsBlock(
-                    [decl_mask]
-                    + shuffles
-                    + [PsConditional(cond, PsBlock([PsStatement(call)]))]))
+                return PsBlock(
+                    #[decl_mask]
+                    shuffles
+                    + [PsConditional(cond, PsBlock([PsStatement(call)]))])
 
     #   Internals
 
-- 
GitLab