diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index f16c28e8c2534e9e9e03cb7e38bf3dea76543c2b..d3e8de42db572c7cc2cdfa807f777ccd4bd76a89 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -264,33 +264,45 @@ class GenericGpu(Platform):
             case _:
                 actual_op = op
 
-        # perform local warp reductions
-        def gen_shuffle_instr(offset: int):
-            full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
-            return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
-                          [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
-
-        num_shuffles = math.frexp(self._warp_size)[1]
-        shuffles = tuple(PsAssignment(symbol_expr,
-                                      compound_op_to_expr(actual_op,
-                                                          symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
-                         for i in reversed(range(1, num_shuffles)))
-
-        # find first thread in warp
+        # check if thread is valid for performing reduction
         ispace = self._ctx.get_iteration_space()
         is_valid_thread = self._get_condition_for_translation(ispace)
-        thread_indices_per_dim = [
-            idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
-            for i, idx in enumerate(THREAD_IDX[:ispace.rank])
-        ]
-        tid: PsExpression = thread_indices_per_dim[0]
-        for t in thread_indices_per_dim[1:]:
-            tid = PsAdd(tid, t)
-        first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(self._warp_size, SInt(32)))),
-                                    PsConstantExpr(PsConstant(0, SInt(32))))
-        cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
-
-        # use atomic operation on first thread of warp to sync
+
+        cond: PsExpression
+        shuffles: tuple[PsAssignment, ...]
+        if self._warp_size and self._assume_warp_aligned_block_size:
+            # perform local warp reductions
+            def gen_shuffle_instr(offset: int):
+                full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
+                return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
+                              [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
+
+            # set up shuffle instructions for warp-level reduction
+            num_shuffles = math.frexp(self._warp_size)[1]
+            shuffles = tuple(PsAssignment(symbol_expr,
+                                          compound_op_to_expr(actual_op,
+                                                              symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
+                             for i in reversed(range(1, num_shuffles)))
+
+            # find first thread in warp
+            thread_indices_per_dim = [
+                idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
+                for i, idx in enumerate(THREAD_IDX[:ispace.rank])
+            ]
+            tid: PsExpression = thread_indices_per_dim[0]
+            for t in thread_indices_per_dim[1:]:
+                tid = PsAdd(tid, t)
+            first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(self._warp_size, SInt(32)))),
+                                        PsConstantExpr(PsConstant(0, SInt(32))))
+
+            # set condition to only execute atomic operation on first valid thread in warp
+            cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
+        else:
+            # no optimization: only execute atomic add on valid thread
+            shuffles = ()
+            cond = is_valid_thread
+
+        # use atomic operation
         call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype],
                                   PsCustomType("void"))
         call.args = (ptr_expr, symbol_expr)