Skip to content
Snippets Groups Projects
Commit 9ec813bd authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Employ optimized warp-level reduction based on check

parent 90837d04
Branches
1 merge request!438Reduction Support
......@@ -264,33 +264,45 @@ class GenericGpu(Platform):
case _:
actual_op = op
# perform local warp reductions
def gen_shuffle_instr(offset: int):
full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
[full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
num_shuffles = math.frexp(self._warp_size)[1]
shuffles = tuple(PsAssignment(symbol_expr,
compound_op_to_expr(actual_op,
symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
for i in reversed(range(1, num_shuffles)))
# find first thread in warp
# check if thread is valid for performing reduction
ispace = self._ctx.get_iteration_space()
is_valid_thread = self._get_condition_for_translation(ispace)
thread_indices_per_dim = [
idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
for i, idx in enumerate(THREAD_IDX[:ispace.rank])
]
tid: PsExpression = thread_indices_per_dim[0]
for t in thread_indices_per_dim[1:]:
tid = PsAdd(tid, t)
first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(self._warp_size, SInt(32)))),
PsConstantExpr(PsConstant(0, SInt(32))))
cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
# use atomic operation on first thread of warp to sync
cond: PsExpression
shuffles: tuple[PsAssignment, ...]
if self._warp_size and self._assume_warp_aligned_block_size:
# perform local warp reductions
def gen_shuffle_instr(offset: int):
full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
[full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
# set up shuffle instructions for warp-level reduction
num_shuffles = math.frexp(self._warp_size)[1]
shuffles = tuple(PsAssignment(symbol_expr,
compound_op_to_expr(actual_op,
symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
for i in reversed(range(1, num_shuffles)))
# find first thread in warp
thread_indices_per_dim = [
idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
for i, idx in enumerate(THREAD_IDX[:ispace.rank])
]
tid: PsExpression = thread_indices_per_dim[0]
for t in thread_indices_per_dim[1:]:
tid = PsAdd(tid, t)
first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(self._warp_size, SInt(32)))),
PsConstantExpr(PsConstant(0, SInt(32))))
# set condition to only execute atomic operation on first valid thread in warp
cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
else:
# no optimization: only execute atomic add on valid thread
shuffles = ()
cond = is_valid_thread
# use atomic operation
call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment