Commit 90837d04 authored by Richard Angersbach

Merge handling for GPU reductions into generic_gpu.py for the time being

parent 0d7526c0
1 merge request: !438 Reduction Support
Pipeline #76840 failed
 from __future__ import annotations
-from abc import ABC, abstractmethod
-from ...types import constify, deconstify
+import math
+import operator
+from abc import ABC, abstractmethod
+from functools import reduce
+from ..ast import PsAstNode
+from ..constants import PsConstant
+from ...compound_op_mapping import compound_op_to_expr
+from ...sympyextensions.reduction import ReductionOp
+from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType
+from ...types.quick import UInt, SInt
 from ..exceptions import MaterializationError
 from .platform import Platform
@@ -15,7 +24,7 @@ from ..kernelcreation import (
 )
 from ..kernelcreation.context import KernelCreationContext
-from ..ast.structural import PsBlock, PsConditional, PsDeclaration
+from ..ast.structural import PsBlock, PsConditional, PsDeclaration, PsStatement, PsAssignment
 from ..ast.expressions import (
     PsExpression,
     PsLiteralExpr,
@@ -23,12 +32,17 @@ from ..ast.expressions import (
     PsCall,
     PsLookup,
     PsBufferAcc,
+    PsSymbolExpr,
+    PsConstantExpr,
+    PsAdd,
+    PsRem,
+    PsEq
 )
 from ..ast.expressions import PsLt, PsAnd
 from ...types import PsSignedIntegerType, PsIeeeFloatType
 from ..literals import PsLiteral
-from ..functions import PsMathFunction, MathFunctions, CFunction
+from ..functions import MathFunctions, CFunction, ReductionFunctions, NumericLimitsFunctions, PsReductionFunction, \
+    PsMathFunction
 
 int32 = PsSignedIntegerType(width=32, const=False)
@@ -174,10 +188,15 @@ class GenericGpu(Platform):
     def __init__(
         self,
         ctx: KernelCreationContext,
+        assume_warp_aligned_block_size: bool,
+        warp_size: int | None,
         thread_mapping: ThreadMapping | None = None,
     ) -> None:
         super().__init__(ctx)
 
+        self._assume_warp_aligned_block_size = assume_warp_aligned_block_size
+        self._warp_size = warp_size
+
         self._thread_mapping = (
             thread_mapping if thread_mapping is not None else Linear3DMapping()
         )
@@ -194,14 +213,107 @@ class GenericGpu(Platform):
         else:
             raise MaterializationError(f"Unknown type of iteration space: {ispace}")
 
-    def select_function(self, call: PsCall) -> PsExpression:
-        assert isinstance(call.function, PsMathFunction)
+    @staticmethod
+    def _get_condition_for_translation(ispace: IterationSpace):
+        if isinstance(ispace, FullIterationSpace):
+            conds = []
+            dimensions = ispace.dimensions_in_loop_order()
+            for dim in dimensions:
+                ctr_expr = PsExpression.make(dim.counter)
+                conds.append(PsLt(ctr_expr, dim.stop))
+
+            condition: PsExpression = conds[0]
+            for cond in conds[1:]:
+                condition = PsAnd(condition, cond)
+            return condition
+        elif isinstance(ispace, SparseIterationSpace):
+            sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
+            stop = PsExpression.make(ispace.index_list.shape[0])
+            return PsLt(sparse_ctr_expr.clone(), stop)
+        else:
+            raise MaterializationError(f"Unknown type of iteration space: {ispace}")
+
+    def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
+        call_func = call.function
+        assert isinstance(call_func, PsReductionFunction | PsMathFunction)
+
+        func = call_func.func
+
+        if isinstance(call_func, PsReductionFunction) and func is ReductionFunctions.WriteBackToPtr:
+            ptr_expr, symbol_expr = call.args
+            op = call_func.reduction_op
+            stype = symbol_expr.dtype
+            ptrtype = ptr_expr.dtype
+
+            assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
+            assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
+            if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
+                raise NotImplementedError("atomic operations are only available for float32/64 datatypes")
+
+            # workaround for subtractions -> use additions for reducing intermediate results
+            # similar to OpenMP reductions: local copies (negative sign) are added at the end
+            match op:
+                case ReductionOp.Sub:
+                    actual_op = ReductionOp.Add
+                case _:
+                    actual_op = op
+
+            # perform local warp reductions
+            def gen_shuffle_instr(offset: int):
+                full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
+                return PsCall(CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
+                              [full_mask, symbol_expr, PsConstantExpr(PsConstant(offset, SInt(32)))])
+
+            num_shuffles = math.frexp(self._warp_size)[1]
+            shuffles = tuple(PsAssignment(symbol_expr,
+                                          compound_op_to_expr(actual_op,
+                                                              symbol_expr, gen_shuffle_instr(pow(2, i - 1))))
+                             for i in reversed(range(1, num_shuffles)))
+
+            # find first thread in warp
+            ispace = self._ctx.get_iteration_space()
+            is_valid_thread = self._get_condition_for_translation(ispace)
+
+            thread_indices_per_dim = [
+                idx * PsConstantExpr(PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32)))
+                for i, idx in enumerate(THREAD_IDX[:ispace.rank])
+            ]
+            tid: PsExpression = thread_indices_per_dim[0]
+            for t in thread_indices_per_dim[1:]:
+                tid = PsAdd(tid, t)
+            first_thread_in_warp = PsEq(PsRem(tid, PsConstantExpr(PsConstant(self._warp_size, SInt(32)))),
+                                        PsConstantExpr(PsConstant(0, SInt(32))))
+            cond = PsAnd(is_valid_thread, first_thread_in_warp) if is_valid_thread else first_thread_in_warp
+
+            # use atomic operation on first thread of warp to sync
+            call.function = CFunction(f"atomic{actual_op.name}", [ptrtype, stype], PsCustomType("void"))
+            call.args = (ptr_expr, symbol_expr)
+
+            # assemble warp reduction
+            return shuffles, PsConditional(cond, PsBlock([PsStatement(call)]))
-        func = call.function.func
+
         dtype = call.get_dtype()
         arg_types = (dtype,) * func.num_args
-        if isinstance(dtype, PsIeeeFloatType):
+
+        if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions:
+            assert isinstance(dtype, PsIeeeFloatType)
+            match func:
+                case NumericLimitsFunctions.Min:
+                    define = "NEG_INFINITY"
+                case NumericLimitsFunctions.Max:
+                    define = "POS_INFINITY"
+                case _:
+                    raise MaterializationError(f"Cannot materialize call to function {func}")
+
+            return PsLiteralExpr(PsLiteral(define, dtype))
+
+        if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions:
             match func:
                 case (
                     MathFunctions.Exp
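The WriteBackToPtr branch above assembles a butterfly reduction across each warp: math.frexp(self._warp_size)[1] is the exponent of the warp size (math.frexp(32) == (0.5, 6)), so for a 32-wide warp the generated __shfl_xor_sync calls use the XOR offsets 16, 8, 4, 2 and 1, after which every lane holds the full warp result and only the first lane of each warp (tid % warp_size == 0) publishes it with a single atomic call (e.g. atomicAdd for ReductionOp.Add). A minimal pure-Python model of that data flow, as a sketch only (warp_butterfly_reduce is a hypothetical stand-in, not part of the commit):

import math
import operator

def warp_butterfly_reduce(lane_values, op=operator.add, warp_size=32):
    # mirrors num_shuffles = math.frexp(self._warp_size)[1] from the hunk:
    # frexp(32) == (0.5, 6), so the loop below runs for i = 5..1
    num_shuffles = math.frexp(warp_size)[1]
    vals = list(lane_values)
    for i in reversed(range(1, num_shuffles)):
        offset = pow(2, i - 1)  # 16, 8, 4, 2, 1 for a 32-wide warp
        # each lane combines its value with the lane at XOR distance `offset`,
        # modelling __shfl_xor_sync(0xffffffff, val, offset)
        vals = [op(v, vals[lane ^ offset]) for lane, v in enumerate(vals)]
    return vals

# after the butterfly every lane holds the warp total, so only the first
# lane of each warp needs to issue the atomic write-back
assert warp_butterfly_reduce(range(32))[0] == sum(range(32))

This also motivates the ReductionOp.Sub remapping at the top of the branch: as in OpenMP reductions, the per-thread partials already carry the negative sign, so both the intermediate combine and the final write-back use addition.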
@@ -262,7 +374,6 @@ class GenericGpu(Platform):
         ctr_mapping = self._thread_mapping(ispace)
 
         indexing_decls = []
-        conds = []
 
         dimensions = ispace.dimensions_in_loop_order()
@@ -276,14 +387,9 @@ class GenericGpu(Platform):
             indexing_decls.append(
                 self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter]))
             )
-            conds.append(PsLt(ctr_expr, dim.stop))
-
-        condition: PsExpression = conds[0]
-        for cond in conds[1:]:
-            condition = PsAnd(condition, cond)
-
-        ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
-        return ast
+        cond = self._get_condition_for_translation(ispace)
+
+        return PsBlock(indexing_decls + [PsConditional(cond, body)])
 
     def _prepend_sparse_translation(
         self, body: PsBlock, ispace: SparseIterationSpace
@@ -313,8 +419,5 @@ class GenericGpu(Platform):
         ]
         body.statements = mappings + body.statements
 
-        stop = PsExpression.make(ispace.index_list.shape[0])
-        condition = PsLt(sparse_ctr_expr.clone(), stop)
-        ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
-        return ast
+        cond = self._get_condition_for_translation(ispace)
+        return PsBlock([sparse_idx_decl, PsConditional(cond, body)])
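With these two hunks, both the dense and the sparse translation paths obtain their out-of-bounds guard from the shared _get_condition_for_translation helper instead of building it inline. For a dense iteration space the helper conjoins one ctr < stop test per dimension; a string-level sketch of that fold (dense_guard is an illustrative stand-in, not the PsExpression API):

from functools import reduce

def dense_guard(counters, stops):
    # one bounds check per dimension, folded left-to-right with AND,
    # like the PsLt/PsAnd chain in _get_condition_for_translation
    conds = [f"{c} < {s}" for c, s in zip(counters, stops)]
    return reduce(lambda acc, c: f"({acc}) && ({c})", conds)

assert dense_guard(["ctr_0", "ctr_1"], ["stop_0", "stop_1"]) == "(ctr_0 < stop_0) && (ctr_1 < stop_1)"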
@@ -475,6 +475,9 @@ class DefaultKernelCreationDriver:
             else None
         )
 
+        assume_warp_aligned_block_size: bool = self._cfg.gpu.get_option("assume_warp_aligned_block_size")
+        warp_size: int | None = self._cfg.gpu.get_option("warp_size")
+
         GpuPlatform: type
         match self._target:
             case Target.CUDA:
@@ -486,6 +489,8 @@ class DefaultKernelCreationDriver:
 
         return GpuPlatform(
             self._ctx,
+            assume_warp_aligned_block_size,
+            warp_size,
             thread_mapping=thread_mapping,
         )
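The driver reads the two new GPU options and passes them positionally into the extended platform constructor. A hypothetical configuration sketch follows; the option names come from the get_option calls in the hunk above, while the attribute-style setter and the CreateKernelConfig/Target imports are assumed rather than shown in this commit:

from pystencils import CreateKernelConfig, Target

cfg = CreateKernelConfig(target=Target.CUDA)
# read back in the driver via self._cfg.gpu.get_option("warp_size")
cfg.gpu.warp_size = 32
# read back via self._cfg.gpu.get_option("assume_warp_aligned_block_size")
cfg.gpu.assume_warp_aligned_block_size = True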