Reduction Support

Open: Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
2 files changed: +107 -31
from __future__ import annotations
import math
import operator
from abc import ABC, abstractmethod
from functools import reduce
from ..ast import PsAstNode
from ..constants import PsConstant
from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions.reduction import ReductionOp
from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType
from ...types.quick import UInt, SInt
from ..exceptions import MaterializationError
from .generic_gpu import GenericGpu
@@ -17,14 +24,14 @@ from ..kernelcreation import (
)
from ..kernelcreation.context import KernelCreationContext
from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement
from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment
from ..ast.expressions import (
PsExpression,
PsLiteralExpr,
PsCast,
PsCall,
PsLookup,
PsBufferAcc, PsSymbolExpr
PsBufferAcc, PsSymbolExpr, PsConstantExpr, PsAdd, PsRem, PsEq
)
from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType, PsIeeeFloatType
@@ -292,28 +299,94 @@ class CudaPlatform(GenericGpu):
case ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args
op = call.function.reduction_op
stype = symbol_expr.dtype
ptrtype = ptr_expr.dtype
warp_size = 32 # TODO: get from platform/user config
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
    raise NotImplementedError("atomic operations are only available for float32/64 datatypes")
def gen_shuffle_instr(offset: int):
return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
[PsLiteralExpr(PsLiteral("0xffffffff", UInt(32))),
symbol_expr,
PsConstantExpr(PsConstant(offset, SInt(32)))])
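For orientation: `__shfl_xor_sync(mask, var, laneMask)` lets every active lane read `var` from the lane whose ID differs from its own by `laneMask`; the call built here passes the full-warp mask `0xffffffff`, the reduction variable, and the XOR offset. A minimal plain-Python emulation of that lane exchange (an illustrative helper, not part of pystencils or CUDA):

def shfl_xor_emulated(lane_values: list[float], lane_mask: int) -> list[float]:
    # every lane receives the value currently held by lane (lane_id XOR lane_mask)
    return [lane_values[lane ^ lane_mask] for lane in range(len(lane_values))]

warp = [float(i) for i in range(32)]
assert shfl_xor_emulated(warp, 16)[0] == 16.0   # offset 16 exchanges the two warp halves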
# workaround for subtractions -> use additions for reducing intermediate results
# similar to OpenMP reductions: local copies (negative sign) are added at the end
match op:
case ReductionOp.Sub:
# workaround for unsupported atomicSub: use atomic add
# similar to OpenMP reductions: local copies (negative sign) are added at the end
call.function = CFunction("atomicAdd", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
actual_op = ReductionOp.Add
case _:
call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
actual_op = op
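The sign trick behind the `Sub` case can be checked with plain arithmetic: every thread's local copy starts at zero and accumulates the negated contributions, so adding those copies to the target is the same as subtracting the original values from it, and the missing `atomicSub` can be replaced by `atomicAdd`. A small sanity check (plain Python, values chosen arbitrarily):

initial = 100.0
values = [3.5, 1.25, 7.0, 0.5]

reference = initial - sum(values)              # semantics of a Sub reduction
local_copies = [-v for v in values]            # local results already carry the negative sign
via_add = initial + sum(local_copies)          # combined with Add at write-back

assert reference == via_add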
# perform local warp reductions
num_shuffles = math.frexp(warp_size)[1]
shuffles = [PsAssignment(symbol_expr, compound_op_to_expr(actual_op, symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
            for i in reversed(range(1, num_shuffles))]
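The shuffle count relies on `warp_size` being a power of two: `math.frexp(32)` returns `(0.5, 6)`, so five butterfly rounds with XOR offsets 16, 8, 4, 2, 1 fold all 32 lanes together. A quick check of that arithmetic (plain Python):

import math

warp_size = 32
num_shuffles = math.frexp(warp_size)[1]                              # 6
offsets = [pow(2, i - 1) for i in reversed(range(1, num_shuffles))]
assert num_shuffles == 6 and offsets == [16, 8, 4, 2, 1]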
# find first thread in warp
ispace = self._ctx.get_iteration_space() # TODO: receive as argument in unfold_function?
is_valid_thread = self._get_condition_for_translation(ispace)
thread_indices_per_dim = [
idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
for i, idx in enumerate(THREAD_IDX[:ispace.rank])
]
tid: PsExpression = thread_indices_per_dim[0]
for t in thread_indices_per_dim[1:]:
tid = PsAdd(tid, t)
first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(warp_size, SInt(32)))),
PsConstantExpr(PsConstant(0, SInt(32))))
cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
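The warp-leader test linearizes the thread index the same way CUDA derives lane IDs: `tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y`, and a thread is first in its warp iff `tid % warp_size == 0`. A plain-Python sketch of that check (block shape picked arbitrarily for illustration):

import operator
from functools import reduce

def is_first_in_warp(thread_idx: tuple[int, ...], block_dim: tuple[int, ...], warp_size: int = 32) -> bool:
    # linearize (threadIdx.x, .y, .z) with strides 1, blockDim.x, blockDim.x * blockDim.y
    tid = sum(idx * reduce(operator.mul, block_dim[:i], 1) for i, idx in enumerate(thread_idx))
    return tid % warp_size == 0

assert is_first_in_warp((0, 0, 0), (16, 16, 1))
assert is_first_in_warp((0, 2, 0), (16, 16, 1))        # tid = 32 -> first lane of the second warp
assert not is_first_in_warp((1, 0, 0), (16, 16, 1))    # tid = 1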
# first thread of each warp combines its warp's result into the target via an atomic operation
call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
# assemble warp reduction
return PsBlock(
shuffles
+ [PsConditional(cond, PsBlock([PsStatement(call)]))])
if not isinstance(symbol_expr.dtype, PsIeeeFloatType) or symbol_expr.dtype.width not in (32, 64):
NotImplementedError("atomicMul is only available for float32/64 datatypes")
return PsStatement(call)
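Taken together, each thread performs log2(warp_size) butterfly updates of its local value and only the first lane of every warp issues the atomic write-back. A host-side simulation of that scheme for an Add reduction (plain Python, no pystencils or CUDA involved) shows why one atomic per warp suffices:

import random

def warp_reduce_add(lane_values: list[float]) -> list[float]:
    # butterfly reduction: after offsets 16, 8, 4, 2, 1 every lane holds the warp-wide sum
    values = list(lane_values)
    for offset in (16, 8, 4, 2, 1):
        values = [v + values[lane ^ offset] for lane, v in enumerate(values)]
    return values

warp_size, num_warps = 32, 4
thread_values = [random.random() for _ in range(warp_size * num_warps)]

result = 0.0                                   # stands in for the reduction target behind ptr_expr
for w in range(num_warps):
    warp = thread_values[w * warp_size:(w + 1) * warp_size]
    result += warp_reduce_add(warp)[0]         # only lane 0 of each warp performs the atomicAdd

assert abs(result - sum(thread_values)) < 1e-9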
# TODO: SYCL platform has very similar code for fetching conditionals -> move to GenericGPU?
# Internals
def _get_condition_for_translation(
        self, ispace: IterationSpace) -> PsExpression | None:
if self._omit_range_check:
    return None
match ispace:
case FullIterationSpace():
dimensions = ispace.dimensions_in_loop_order()
conds = []
for dim in dimensions:
ctr_expr = PsExpression.make(dim.counter)
conds.append(PsLt(ctr_expr, dim.stop))
if conds:
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
return condition
else:
return None
case SparseIterationSpace():
sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
stop = PsExpression.make(ispace.index_list.shape[0])
return PsLt(sparse_ctr_expr.clone(), stop)
case _:
assert False, "Unknown iteration space"
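For a full iteration space the helper AND-chains one bounds check per dimension and returns None when range checks are omitted; for a sparse space it emits a single check against the index-list length. The chaining logic mirrors this plain-Python sketch (illustrative strings only, no pystencils types):

def fold_guards(guards: list[str]) -> str | None:
    # AND-chain per-dimension guards left to right; an empty list means no conditional at all
    if not guards:
        return None
    condition = guards[0]
    for guard in guards[1:]:
        condition = f"({condition}) && ({guard})"
    return condition

assert fold_guards(["ctr_0 < stop_0", "ctr_1 < stop_1"]) == "(ctr_0 < stop_0) && (ctr_1 < stop_1)"
assert fold_guards([]) is None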
def _prepend_dense_translation(
self, body: PsBlock, ispace: FullIterationSpace
@@ -321,7 +394,7 @@ class CudaPlatform(GenericGpu):
ctr_mapping = self._thread_mapping(ispace)
indexing_decls = []
conds = []
cond = self._get_condition_for_translation(ispace)
dimensions = ispace.dimensions_in_loop_order()
@@ -335,14 +408,9 @@ class CudaPlatform(GenericGpu):
indexing_decls.append(
self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter]))
)
if not self._omit_range_check:
conds.append(PsLt(ctr_expr, dim.stop))
if conds:
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
if cond:
ast = PsBlock(indexing_decls + [PsConditional(cond, body)])
else:
body.statements = indexing_decls + body.statements
ast = body
@@ -355,6 +423,8 @@ class CudaPlatform(GenericGpu):
factory = AstFactory(self._ctx)
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
cond = self._get_condition_for_translation(ispace)
sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
ctr_mapping = self._thread_mapping(ispace)
@@ -377,10 +447,8 @@ class CudaPlatform(GenericGpu):
]
body.statements = mappings + body.statements
if not self._omit_range_check:
stop = PsExpression.make(ispace.index_list.shape[0])
condition = PsLt(sparse_ctr_expr.clone(), stop)
ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
if cond:
ast = PsBlock([sparse_idx_decl, PsConditional(cond, body)])
else:
body.statements = [sparse_idx_decl] + body.statements
ast = body
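Both translation paths now end with the same assembly step: if the helper produced a guard, the body is wrapped in a single PsConditional behind the indexing (or sparse index) declarations; otherwise those declarations are spliced directly in front of the unguarded body. A compact plain-Python sketch of that pattern (hypothetical names, no pystencils types):

def assemble(prologue: list[str], body: list[str], cond: str | None) -> list[str]:
    # one guard around the whole body if a condition exists, otherwise emit it unguarded
    if cond is not None:
        return prologue + [f"if ({cond}) {{ {'; '.join(body)} }}"]
    return prologue + body

assert assemble(["int ctr_0 = ..."], ["stmt"], "ctr_0 < stop_0") == ["int ctr_0 = ...", "if (ctr_0 < stop_0) { stmt }"]
assert assemble(["int ctr_0 = ..."], ["stmt"], None) == ["int ctr_0 = ...", "stmt"]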