Commit 2b6589b8 authored by Richard Angersbach

Introduce masks for warp reductions and fix errors when shuffling warp results

parent c7b564be
Merge request: !438 Reduction Support
Pipeline #74295 failed in 7 minutes and 27 seconds
@@ -5,6 +5,8 @@ import operator
 from abc import ABC, abstractmethod
 from functools import reduce
 
+from pystencils.types import PsBoolType
+
 from ..ast import PsAstNode
 from ..constants import PsConstant
 from ...compound_op_mapping import compound_op_to_expr
@@ -310,11 +312,9 @@ class CudaPlatform(GenericGpu):
         if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
             raise NotImplementedError("atomic operations are only available for float32/64 datatypes")
 
-        def gen_shuffle_instr(offset: int):
-            return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
-                          [PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))),
-                           symbol_expr,
-                           PsConstantExpr(PsConstant(offset, SInt(32)))])
+        # set up mask symbol for active threads in warp
+        mask = PsSymbol("__shfl_mask", UInt(32))
+        self._ctx.add_symbol(mask)
 
         # workaround for subtractions -> use additions for reducing intermediate results
         # similar to OpenMP reductions: local copies (negative sign) are added at the end
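Background for this hunk: __shfl_xor_sync takes a member mask that must name every lane participating in the shuffle, so the previously hard-coded full mask 0xffffffff is only valid while all 32 lanes of the warp are active. The commit therefore introduces a __shfl_mask symbol whose value is later derived from the validity predicate via __ballot_sync. A minimal CUDA sketch of how such a mask is obtained; is_valid is an illustrative name, not a symbol from the generated code:

// Sketch only: compute the set of lanes for which a predicate holds.
// Must be executed by the whole, still-converged warp, so the full mask
// may be passed to the ballot itself.
__device__ unsigned active_lane_mask(bool is_valid) {
    return __ballot_sync(0xffffffffu, is_valid);  // bit i set iff lane i is valid
}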
@@ -325,8 +325,13 @@ class CudaPlatform(GenericGpu):
             actual_op = op
 
         # perform local warp reductions
-        num_shuffles = math.frexp(warp_size)[1] - 1
-        shuffles = [PsAssignment(symbol_expr, compound_op_to_expr(actual_op, symbol_expr, gen_shuffle_instr(i)))
+        def gen_shuffle_instr(offset: int):
+            return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
+                          [PsSymbolExpr(mask), symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
+
+        num_shuffles = math.frexp(warp_size)[1]
+        shuffles = [PsAssignment(symbol_expr,
+                                 compound_op_to_expr(actual_op, symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
                     for i in reversed(range(1, num_shuffles))]
 
         # find first thread in warp
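This hunk also fixes the shuffle offsets: the old code passed the loop index i directly, giving offsets 4, 3, 2, 1 for a 32-wide warp, whereas a butterfly reduction needs the powers of two 16, 8, 4, 2, 1. With math.frexp(32)[1] == 6, the new pow(2, i - 1) for i from 5 down to 1 produces exactly that sequence. A hand-written CUDA analogue of the resulting ladder, assuming addition as the reduction operation and a fully active warp (sketch only, not the generated code):

// Sketch only: XOR-butterfly sum over a fully active 32-lane warp.
// After the last step every lane holds the warp-wide total.
__device__ float warp_reduce_sum(float val, unsigned member_mask) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(member_mask, val, offset);
    }
    return val;
}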
@@ -343,14 +348,21 @@ class CudaPlatform(GenericGpu):
                                               PsConstantExpr(PsConstant(0, SInt(32))))
         cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
 
+        full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
+        ballot_instr = PsCall(CFunction("__ballot_sync", [UInt(32), SInt(32)], SInt(32)),
+                              [full_mask, is_valid_thread])
+        decl_mask = PsDeclaration(PsSymbolExpr(mask), ballot_instr if is_valid_thread else full_mask)
+
         # use atomic operation on first thread of warp to sync
         call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
         call.args = (ptr_expr, symbol_expr)
 
         # assemble warp reduction
-        return PsBlock(
-            shuffles
-            + [PsConditional(cond, PsBlock([PsStatement(call)]))])
+        return PsConditional(is_valid_thread if is_valid_thread else PsConstantExpr(PsLiteral("true", PsBoolType)),
+                             PsBlock(
+                                 [decl_mask]
+                                 + shuffles
+                                 + [PsConditional(cond, PsBlock([PsStatement(call)]))]))
 
 
 # Internals
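Taken together, the assembled AST corresponds to a warp-level reduction followed by one atomic per warp: build the member mask via __ballot_sync, run the shuffle ladder, and let the first thread of the warp commit the partial result. The following self-contained kernel is a conventional hand-written version of that pattern for a float sum, not the code pystencils emits: it assumes a block size that is a multiple of 32, lets out-of-range lanes contribute the neutral element instead of masking them out, and uses illustrative names (reduce_demo, in, result) rather than generated symbols.

#include <cuda_runtime.h>

// Sketch only: warp shuffle ladder + one atomicAdd per warp for a float sum.
// Out-of-range lanes contribute 0.0f so the whole warp can shuffle with the
// full member mask; assumes blockDim.x is a multiple of 32.
__global__ void reduce_demo(const float *in, float *result, int n) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    float val = (tid < n) ? in[tid] : 0.0f;

    // butterfly reduction: offsets 16, 8, 4, 2, 1
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xffffffffu, val, offset);
    }

    // the first lane of each warp adds the warp's partial sum to the result
    if ((threadIdx.x & 31) == 0) {
        atomicAdd(result, val);
    }
}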
@@ -359,7 +371,7 @@ class CudaPlatform(GenericGpu):
     def _get_condition_for_translation(
             self, ispace: IterationSpace):
 
-        if not self._omit_range_check:
+        if self._omit_range_check:
             return None
 
         match ispace: