Commit 90837d04 authored by Richard Angersbach

Merge handling for GPU reductions into generic_gpu.py for the time being

parent 0d7526c0
1 merge request: !438 Reduction Support
Pipeline #76840 failed with stages in 17 minutes and 30 seconds
 from __future__ import annotations
-from abc import ABC, abstractmethod
-from ...types import constify, deconstify
+import math
+import operator
+from abc import ABC, abstractmethod
+from functools import reduce
+from ..ast import PsAstNode
+from ..constants import PsConstant
+from ...compound_op_mapping import compound_op_to_expr
+from ...sympyextensions.reduction import ReductionOp
+from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType
+from ...types.quick import UInt, SInt
 from ..exceptions import MaterializationError
 from .platform import Platform
@@ -15,7 +24,7 @@ from ..kernelcreation import (
 )
 from ..kernelcreation.context import KernelCreationContext
-from ..ast.structural import PsBlock, PsConditional, PsDeclaration
+from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment
 from ..ast.expressions import (
     PsExpression,
     PsLiteralExpr,
@@ -23,12 +32,17 @@ from ..ast.expressions import (
     PsCall,
     PsLookup,
     PsBufferAcc,
+    PsSymbolExpr,
+    PsConstantExpr,
+    PsAdd,
+    PsRem,
+    PsEq
 )
 from ..ast.expressions import PsLt, PsAnd
 from ...types import PsSignedIntegerType, PsIeeeFloatType
 from ..literals import PsLiteral
-from ..functions import PsMathFunction, MathFunctions, CFunction
+from ..functions import MathFunctions, CFunction, ReductionFunctions, NumericLimitsFunctions, PsReductionFunction, \
+    PsMathFunction
 
 int32 = PsSignedIntegerType(width=32, const=False)
@@ -174,10 +188,15 @@ class GenericGpu(Platform):
     def __init__(
         self,
         ctx: KernelCreationContext,
+        assume_warp_aligned_block_size: bool,
+        warp_size: int | None,
         thread_mapping: ThreadMapping | None = None,
     ) -> None:
         super().__init__(ctx)
+        self._assume_warp_aligned_block_size = assume_warp_aligned_block_size
+        self._warp_size = warp_size
         self._thread_mapping = (
             thread_mapping if thread_mapping is not None else Linear3DMapping()
         )
@@ -194,14 +213,107 @@
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
-    def select_function(self, call: PsCall) -> PsExpression:
-        assert isinstance(call.function, PsMathFunction)
+    @staticmethod
+    def _get_condition_for_translation(ispace: IterationSpace):
+        if isinstance(ispace, FullIterationSpace):
+            conds = []
+            dimensions = ispace.dimensions_in_loop_order()
+            for dim in dimensions:
+                ctr_expr = PsExpression.make(dim.counter)
+                conds.append(PsLt(ctr_expr, dim.stop))
+            condition: PsExpression = conds[0]
+            for cond in conds[1:]:
+                condition = PsAnd(condition, cond)
+            return condition
+        elif isinstance(ispace, SparseIterationSpace):
+            sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
+            stop = PsExpression.make(ispace.index_list.shape[0])
+            return PsLt(sparse_ctr_expr.clone(), stop)
+        else:
+            raise MaterializationError(f"Unknown type of iteration space: {ispace}")
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
+        call_func = call.function
+        assert isinstance(call_func, PsReductionFunction | PsMathFunction)
+        func = call_func.func
+        if isinstance(call_func, PsReductionFunction) and func is ReductionFunctions.WriteBackToPtr:
+            ptr_expr, symbol_expr = call.args
+            op = call_func.reduction_op
+            stype = symbol_expr.dtype
+            ptrtype = ptr_expr.dtype
+            assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
+            assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
+            if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
+                raise NotImplementedError("atomic operations are only available for float32/64 datatypes")
+            # workaround for subtractions -> use additions for reducing intermediate results
+            # similar to OpenMP reductions: local copies (negative sign) are added at the end
+            match op:
+                case ReductionOp.Sub:
+                    actual_op = ReductionOp.Add
+                case _:
+                    actual_op = op
+            # perform local warp reductions
+            def gen_shuffle_instr(offset: int):
+                full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
+                return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
+                              [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
+            num_shuffles = math.frexp(self._warp_size)[1]
+            shuffles = tuple(PsAssignment(symbol_expr,
+                                          compound_op_to_expr(actual_op,
+                                                              symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
+                             for i in reversed(range(1, num_shuffles)))
+            # find first thread in warp
+            ispace = self._ctx.get_iteration_space()
+            is_valid_thread = self._get_condition_for_translation(ispace)
+            thread_indices_per_dim = [
+                idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
+                for i, idx in enumerate(THREAD_IDX[:ispace.rank])
+            ]
+            tid: PsExpression = thread_indices_per_dim[0]
+            for t in thread_indices_per_dim[1:]:
+                tid = PsAdd(tid, t)
+            first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(self._warp_size, SInt(32)))),
+                                        PsConstantExpr(PsConstant(0, SInt(32))))
+            cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
+            # use atomic operation on first thread of warp to sync
+            call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
+            call.args = (ptr_expr, symbol_expr)
+            # assemble warp reduction
+            return shuffles, PsConditional(cond, PsBlock([PsStatement(call)]))
-        func = call.function.func
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
-        if isinstance(dtype, PsIeeeFloatType):
+        if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions:
+            assert isinstance(dtype, PsIeeeFloatType)
+            match func:
+                case NumericLimitsFunctions.Min:
+                    define = "NEG_INFINITY"
+                case NumericLimitsFunctions.Max:
+                    define = "POS_INFINITY"
+                case _:
+                    raise MaterializationError(f"Cannot materialize call to function {func}")
+            return PsLiteralExpr(PsLiteral(define, dtype))
+        if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions:
             match func:
                 case (
                     MathFunctions.Exp
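Aside: the __shfl_xor_sync sequence generated above performs a butterfly (XOR) reduction across one warp. Below is a minimal Python model of the emitted pattern, assuming warp_size = 32 and an Add reduction; the helper name warp_reduce_add is illustrative only and not part of the commit.

# Illustrative model only -- not part of this commit.
import math

warp_size = 32
# math.frexp(32) == (0.5, 6), so num_shuffles == 6 and the offsets are 16, 8, 4, 2, 1,
# matching the pow(2, i - 1) shuffle offsets for i in reversed(range(1, num_shuffles)).
num_shuffles = math.frexp(warp_size)[1]
offsets = [2 ** (i - 1) for i in reversed(range(1, num_shuffles))]

def warp_reduce_add(lane_values):
    # lane_values holds one partial result per lane; each step combines lanes whose ids differ by `off`
    vals = list(lane_values)
    for off in offsets:
        vals = [v + vals[lane ^ off] for lane, v in enumerate(vals)]
    # after all steps every lane holds the full warp sum; in the generated kernel only the
    # first lane of each warp then issues the atomic write-back to global memory
    return vals[0]

assert warp_reduce_add(range(32)) == sum(range(32))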
@@ -262,7 +374,6 @@ class GenericGpu(Platform):
         ctr_mapping = self._thread_mapping(ispace)
         indexing_decls = []
-        conds = []
         dimensions = ispace.dimensions_in_loop_order()
@@ -276,14 +387,9 @@
             indexing_decls.append(
                 self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter]))
             )
-            conds.append(PsLt(ctr_expr, dim.stop))
-        condition: PsExpression = conds[0]
-        for cond in conds[1:]:
-            condition = PsAnd(condition, cond)
-        ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
-        return ast
+        cond = self._get_condition_for_translation(ispace)
+        return PsBlock(indexing_decls + [PsConditional(cond, body)])
 
     def _prepend_sparse_translation(
         self, body: PsBlock, ispace: SparseIterationSpace
@@ -313,8 +419,5 @@
         ]
         body.statements = mappings + body.statements
-        stop = PsExpression.make(ispace.index_list.shape[0])
-        condition = PsLt(sparse_ctr_expr.clone(), stop)
-        ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
-        return ast
+        cond = self._get_condition_for_translation(ispace)
+        return PsBlock([sparse_idx_decl, PsConditional(cond, body)])
@@ -475,6 +475,9 @@ class DefaultKernelCreationDriver:
             else None
         )
+        assume_warp_aligned_block_size: bool = self._cfg.gpu.get_option("assume_warp_aligned_block_size")
+        warp_size: int | None = self._cfg.gpu.get_option("warp_size")
         GpuPlatform: type
         match self._target:
             case Target.CUDA:
@@ -486,6 +489,8 @@
         return GpuPlatform(
             self._ctx,
+            assume_warp_aligned_block_size,
+            warp_size,
             thread_mapping=thread_mapping,
         )
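For context, the two options read here come from the GPU category of the kernel creation config. A hedged usage sketch follows: the option names "warp_size" and "assume_warp_aligned_block_size" are taken from the diff above, while ps.CreateKernelConfig as the entry point and attribute-style access on cfg.gpu are assumptions about the surrounding API.

# Hedged sketch: how the new options might be set by a user.
import pystencils as ps

cfg = ps.CreateKernelConfig(target=ps.Target.CUDA)
cfg.gpu.warp_size = 32                         # hardware warp width used for the shuffle reduction
cfg.gpu.assume_warp_aligned_block_size = True  # promise that block sizes are multiples of the warp size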