Skip to content
Snippets Groups Projects

Reduction Support

Merged Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Compare and Show latest version
5 files
+ 226
− 119
Preferences
Compare changes
Files
5
from __future__ import annotations
import math
from .generic_gpu import GenericGpu
from ..ast import PsAstNode
from ..ast.expressions import (
PsExpression,
PsLiteralExpr,
PsCall,
PsAnd,
PsConstantExpr,
PsSymbolExpr,
)
from ..ast.structural import (
PsConditional,
PsStatement,
PsAssignment,
PsBlock,
PsStructuralNode,
)
from ..constants import PsConstant
from ..exceptions import MaterializationError
from ..functions import NumericLimitsFunctions, CFunction
from ..literals import PsLiteral
from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions import ReductionOp
from ...types import PsType, PsIeeeFloatType, PsCustomType, PsPointerType, PsScalarType
from ...types.quick import SInt, UInt
class CudaPlatform(GenericGpu):
"""Platform for the CUDA GPU taret."""
"""Platform for the CUDA GPU target."""
@property
def required_headers(self) -> set[str]:
return set()
return super().required_headers | {
'"npp.h"',
}
def resolve_reduction(
self,
ptr_expr: PsExpression,
symbol_expr: PsExpression,
reduction_op: ReductionOp,
) -> tuple[tuple[PsStructuralNode, ...], PsAstNode]:
stype = symbol_expr.dtype
ptrtype = ptr_expr.dtype
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
NotImplementedError(
"atomic operations are only available for float32/64 datatypes"
)
# workaround for subtractions -> use additions for reducing intermediate results
# similar to OpenMP reductions: local copies (negative sign) are added at the end
match reduction_op:
case ReductionOp.Sub:
actual_reduction_op = ReductionOp.Add
case _:
actual_reduction_op = reduction_op
# check if thread is valid for performing reduction
ispace = self._ctx.get_iteration_space()
is_valid_thread = self._get_condition_for_translation(ispace)
cond: PsExpression
shuffles: tuple[PsAssignment, ...]
if self._warp_size and self._assume_warp_aligned_block_size:
# perform local warp reductions
def gen_shuffle_instr(offset: int):
full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
return PsCall(
CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
[
full_mask,
symbol_expr,
PsConstantExpr(PsConstant(offset, SInt(32))),
],
)
# set up shuffle instructions for warp-level reduction
num_shuffles = math.frexp(self._warp_size)[1]
shuffles = tuple(
PsAssignment(
symbol_expr,
compound_op_to_expr(
actual_reduction_op,
symbol_expr,
gen_shuffle_instr(pow(2, i - 1)),
),
)
for i in reversed(range(1, num_shuffles))
)
# find first thread in warp
first_thread_in_warp = self._first_thread_in_warp(ispace)
# set condition to only execute atomic operation on first valid thread in warp
cond = (
PsAnd(is_valid_thread, first_thread_in_warp)
if is_valid_thread
else first_thread_in_warp
)
else:
# no optimization: only execute atomic add on valid thread
shuffles = ()
cond = is_valid_thread
# use atomic operation
func = CFunction(
f"atomic{actual_reduction_op.name}", [ptrtype, stype], PsCustomType("void")
)
func_args = (ptr_expr, symbol_expr)
# assemble warp reduction
return shuffles, PsConditional(
cond, PsBlock([PsStatement(PsCall(func, func_args))])
)
def resolve_numeric_limits(
self, func: NumericLimitsFunctions, dtype: PsType
) -> PsExpression:
assert isinstance(dtype, PsIeeeFloatType)
match func:
case NumericLimitsFunctions.Min:
define = f"NPP_MINABS_{dtype.width}F"
case NumericLimitsFunctions.Max:
define = f"NPP_MAXABS_{dtype.width}F"
case _:
raise MaterializationError(
f"Cannot materialize call to function {func}"
)
return PsLiteralExpr(PsLiteral(define, dtype))