Skip to content
Snippets Groups Projects
Commit ad292c2b authored by Richard Angersbach
Browse files

Add initial version of warp-level reduction for CUDA

parent ce816539
No related branches found
No related tags found
1 merge request: !438 "Reduction Support"
Pipeline #74286 failed
from __future__ import annotations
import math
import operator
from abc import ABC, abstractmethod
from functools import reduce
from ..ast import PsAstNode
from ..constants import PsConstant
from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions.reduction import ReductionOp
from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType
from ...types.quick import UInt, SInt
from ..exceptions import MaterializationError
from .generic_gpu import GenericGpu
......@@ -17,14 +24,14 @@ from ..kernelcreation import (
)
from ..kernelcreation.context import KernelCreationContext
from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement
from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment
from ..ast.expressions import (
PsExpression,
PsLiteralExpr,
PsCast,
PsCall,
PsLookup,
PsBufferAcc, PsSymbolExpr
PsBufferAcc, PsSymbolExpr, PsConstantExpr, PsAdd, PsRem, PsEq
)
from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType, PsIeeeFloatType
......@@ -292,26 +299,58 @@ class CudaPlatform(GenericGpu):
case ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args
op = call.function.reduction_op
stype = symbol_expr.dtype
ptrtype = ptr_expr.dtype
warp_size = 32 # TODO: get from platform/user config
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
NotImplementedError("atomic operations are only available for float32/64 datatypes")
def gen_shuffle_instr(offset: int):
return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
[PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))),
symbol_expr,
PsConstantExpr(PsConstant(offset, SInt(32)))])
# workaround for subtractions -> use additions for reducing intermediate results
# similar to OpenMP reductions: local copies (negative sign) are added at the end
match op:
case ReductionOp.Sub:
# workaround for unsupported atomicSub: use atomic add
# similar to OpenMP reductions: local copies (negative sign) are added at the end
call.function = CFunction("atomicAdd", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
actual_op = ReductionOp.Add
case _:
call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
if not isinstance(symbol_expr.dtype, PsIeeeFloatType) or symbol_expr.dtype.width not in (32, 64):
NotImplementedError("atomicMul is only available for float32/64 datatypes")
return PsStatement(call)
actual_op = op
# perform local warp reductions
num_shuffles = math.frexp(warp_size)[1] - 1
shuffles = [PsAssignment(symbol_expr, compound_op_to_expr(actual_op, symbol_expr, gen_shuffle_instr(i)))
for i in reversed(range(1, num_shuffles))]
# find first thread in warp
ispace = self._ctx.get_iteration_space() # TODO: receive as argument in unfold_function?
is_valid_thread = self._get_condition_for_translation(ispace)
thread_indices_per_dim = [
idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
for i, idx in enumerate(THREAD_IDX[:ispace.rank])
]
tid: PsExpression = thread_indices_per_dim[0]
for t in thread_indices_per_dim[1:]:
tid = PsAdd(tid, t)
first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
PsConstantExpr(PsConstant(0, SInt(32))))
cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
# use atomic operation on first thread of warp to sync
call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
# assemble warp reduction
return PsBlock(
shuffles
+ [PsConditional(cond, PsBlock([PsStatement(call)]))])
# Internals
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment