Skip to content
Snippets Groups Projects
Commit ad292c2b authored by Richard Angersbach
Browse files

Add initial version of warp-level reduction for CUDA

parent ce816539
1 merge request: !438 "Reduction Support"
Pipeline #74286 failed with stages
in 12 minutes and 45 seconds
from __future__ import annotations from __future__ import annotations
import math
import operator
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import reduce
from ..ast import PsAstNode from ..ast import PsAstNode
from ..constants import PsConstant
from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions.reduction import ReductionOp from ...sympyextensions.reduction import ReductionOp
from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType
from ...types.quick import UInt, SInt
from ..exceptions import MaterializationError from ..exceptions import MaterializationError
from .generic_gpu import GenericGpu from .generic_gpu import GenericGpu
...@@ -17,14 +24,14 @@ from ..kernelcreation import ( ...@@ -17,14 +24,14 @@ from ..kernelcreation import (
) )
from ..kernelcreation.context import KernelCreationContext from ..kernelcreation.context import KernelCreationContext
from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment
from ..ast.expressions import ( from ..ast.expressions import (
PsExpression, PsExpression,
PsLiteralExpr, PsLiteralExpr,
PsCast, PsCast,
PsCall, PsCall,
PsLookup, PsLookup,
PsBufferAcc, PsSymbolExpr PsBufferAcc, PsSymbolExpr, PsConstantExpr, PsAdd, PsRem, PsEq
) )
from ..ast.expressions import PsLt, PsAnd from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType, PsIeeeFloatType from ...types import PsSignedIntegerType, PsIeeeFloatType
...@@ -292,26 +299,58 @@ class CudaPlatform(GenericGpu): ...@@ -292,26 +299,58 @@ class CudaPlatform(GenericGpu):
case ReductionFunctions.WriteBackToPtr: case ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args ptr_expr, symbol_expr = call.args
op = call.function.reduction_op op = call.function.reduction_op
stype = symbol_expr.dtype
ptrtype = ptr_expr.dtype
warp_size = 32 # TODO: get from platform/user config
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType) if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType) NotImplementedError("atomic operations are only available for float32/64 datatypes")
def gen_shuffle_instr(offset: int):
    """Build a call expression to the ``__shfl_xor_sync`` intrinsic.

    The call receives three arguments: the literal mask ``0xffffffff``
    (typed as 32-bit unsigned), the enclosing scope's ``symbol_expr``, and
    ``offset`` wrapped as a 32-bit signed constant. The declared return type
    of the intrinsic is ``stype``, taken from the enclosing scope.

    Args:
        offset: Integer passed as the third argument of the shuffle call.

    Returns:
        A ``PsCall`` node invoking ``__shfl_xor_sync``.
    """
    # Signature of the intrinsic: (mask, value, lane offset) -> value type.
    shuffle_fn = CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype)
    # NOTE(review): an all-ones mask presumably selects every lane of a
    # 32-wide warp; confirm once warp_size becomes configurable.
    full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
    lane_offset = PsConstantExpr(PsConstant(offset, SInt(32)))
    return PsCall(shuffle_fn, [full_mask, symbol_expr, lane_offset])
# workaround for subtractions -> use additions for reducing intermediate results
# similar to OpenMP reductions: local copies (negative sign) are added at the end
match op: match op:
case ReductionOp.Sub: case ReductionOp.Sub:
# workaround for unsupported atomicSub: use atomic add actual_op = ReductionOp.Add
# similar to OpenMP reductions: local copies (negative sign) are added at the end
call.function = CFunction("atomicAdd", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
case _: case _:
call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype], actual_op = op
PsCustomType("void"))
call.args = (ptr_expr, symbol_expr) # perform local warp reductions
num_shuffles = math.frexp(warp_size)[1] - 1
if not isinstance(symbol_expr.dtype, PsIeeeFloatType) or symbol_expr.dtype.width not in (32, 64): shuffles = [PsAssignment(symbol_expr, compound_op_to_expr(actual_op, symbol_expr, gen_shuffle_instr(i)))
NotImplementedError("atomicMul is only available for float32/64 datatypes") for i in reversed(range(1, num_shuffles))]
return PsStatement(call) # find first thread in warp
ispace = self._ctx.get_iteration_space() # TODO: receive as argument in unfold_function?
is_valid_thread = self._get_condition_for_translation(ispace)
thread_indices_per_dim = [
idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
for i, idx in enumerate(THREAD_IDX[:ispace.rank])
]
tid: PsExpression = thread_indices_per_dim[0]
for t in thread_indices_per_dim[1:]:
tid = PsAdd(tid, t)
first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
PsConstantExpr(PsConstant(0, SInt(32))))
cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
# use atomic operation on first thread of warp to sync
call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
# assemble warp reduction
return PsBlock(
shuffles
+ [PsConditional(cond, PsBlock([PsStatement(call)]))])
# Internals # Internals
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment