diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 873961cc75015fe1f654bf70e313e48aab102131..8936bf73f90f5950a8a181d403f312a65e24ee07 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -313,8 +313,9 @@ class CudaPlatform(GenericGpu): NotImplementedError("atomic operations are only available for float32/64 datatypes") # set up mask symbol for active threads in warp - mask = PsSymbol("__shfl_mask", UInt(32)) - self._ctx.add_symbol(mask) + #mask = PsSymbol("__shfl_mask", UInt(32)) + #self._ctx.add_symbol(mask) + full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))) # workaround for subtractions -> use additions for reducing intermediate results # similar to OpenMP reductions: local copies (negative sign) are added at the end @@ -327,7 +328,7 @@ class CudaPlatform(GenericGpu): # perform local warp reductions def gen_shuffle_instr(offset: int): return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype), - [PsSymbolExpr(mask), symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))]) + [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))]) num_shuffles = math.frexp(warp_size)[1] shuffles = [PsAssignment(symbol_expr, @@ -348,21 +349,19 @@ class CudaPlatform(GenericGpu): PsConstantExpr(PsConstant(0, SInt(32)))) cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp - full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))) - ballot_instr = PsCall(CFunction("__ballot_sync", [UInt(32), SInt(32)], SInt(32)), - [full_mask, is_valid_thread]) - decl_mask = PsDeclaration(PsSymbolExpr(mask), ballot_instr if is_valid_thread else full_mask) + #ballot_instr = PsCall(CFunction("__ballot_sync", [UInt(32), SInt(32)], SInt(32)), + # [full_mask, is_valid_thread]) + #decl_mask = PsDeclaration(full_mask) # use atomic operation on first thread of warp to sync call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void")) call.args = (ptr_expr, symbol_expr) # assemble warp reduction - return PsConditional(is_valid_thread if is_valid_thread else PsConstantExpr(PsLiteral("true", PsBoolType)), - PsBlock( - [decl_mask] - + shuffles - + [PsConditional(cond, PsBlock([PsStatement(call)]))])) + return PsBlock( + #[decl_mask] + shuffles + + [PsConditional(cond, PsBlock([PsStatement(call)]))]) # Internals