diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 6f32102deafac5a40c37fe13bcb62bc5426107c0..873961cc75015fe1f654bf70e313e48aab102131 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -5,6 +5,8 @@ import operator from abc import ABC, abstractmethod from functools import reduce +from pystencils.types import PsBoolType + from ..ast import PsAstNode from ..constants import PsConstant from ...compound_op_mapping import compound_op_to_expr @@ -310,11 +312,9 @@ class CudaPlatform(GenericGpu): if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64): NotImplementedError("atomic operations are only available for float32/64 datatypes") - def gen_shuffle_instr(offset: int): - return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype), - [PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))), - symbol_expr, - PsConstantExpr(PsConstant(offset, SInt(32)))]) + # set up mask symbol for active threads in warp + mask = PsSymbol("__shfl_mask", UInt(32)) + self._ctx.add_symbol(mask) # workaround for subtractions -> use additions for reducing intermediate results # similar to OpenMP reductions: local copies (negative sign) are added at the end @@ -325,8 +325,13 @@ class CudaPlatform(GenericGpu): actual_op = op # perform local warp reductions - num_shuffles = math.frexp(warp_size)[1] - 1 - shuffles = [PsAssignment(symbol_expr, compound_op_to_expr(actual_op, symbol_expr, gen_shuffle_instr(i))) + def gen_shuffle_instr(offset: int): + return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype), + [PsSymbolExpr(mask), symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))]) + + num_shuffles = math.frexp(warp_size)[1] + shuffles = [PsAssignment(symbol_expr, + compound_op_to_expr(actual_op, symbol_expr, gen_shuffle_instr(pow(2, i - 1)))) for i in reversed(range(1, num_shuffles))] # find first thread in warp @@ -343,14 +348,21 @@ class 
CudaPlatform(GenericGpu): PsConstantExpr(PsConstant(0, SInt(32)))) cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp + full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))) + ballot_instr = PsCall(CFunction("__ballot_sync", [UInt(32), SInt(32)], UInt(32)), + [full_mask, is_valid_thread]) + decl_mask = PsDeclaration(PsSymbolExpr(mask), ballot_instr if is_valid_thread else full_mask) + # use atomic operation on first thread of warp to sync call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void")) call.args = (ptr_expr, symbol_expr) # assemble warp reduction - return PsBlock( - shuffles - + [PsConditional(cond, PsBlock([PsStatement(call)]))]) + return PsConditional(is_valid_thread if is_valid_thread else PsLiteralExpr(PsLiteral("true", PsBoolType())), + PsBlock( + [decl_mask] + + shuffles + + [PsConditional(cond, PsBlock([PsStatement(call)]))])) # Internals @@ -359,7 +371,7 @@ class CudaPlatform(GenericGpu): def _get_condition_for_translation( self, ispace: IterationSpace): - if not self._omit_range_check: + if self._omit_range_check: return None match ispace: