Commit 806dcb6b authored by Richard Angersbach

Move resolution of reductions to concrete gpu platform classes

parent 4e5c89b9
1 merge request: !438 Reduction Support
Pipeline #77395 passed with stages in 15 minutes and 43 seconds
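For orientation, the structural effect of this commit can be sketched as follows. This is a toy outline using only names visible in the diff below, not the actual pystencils class bodies:

from abc import ABC, abstractmethod

class GenericGpu(ABC):
    # GenericGpu now only declares the reduction hook ...
    @abstractmethod
    def resolve_reduction(self, ptr_expr, symbol_expr, reduction_op): ...

class CudaPlatform(GenericGpu):
    # ... CudaPlatform implements it with warp shuffles plus an atomic ...
    def resolve_reduction(self, ptr_expr, symbol_expr, reduction_op):
        return (), None  # placeholder; the real method returns AST nodes

class HipPlatform(GenericGpu):
    # ... and HipPlatform rejects it for now (the real code raises MaterializationError).
    def resolve_reduction(self, ptr_expr, symbol_expr, reduction_op):
        raise NotImplementedError("Reductions are not yet supported in the HIP backend.")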
from __future__ import annotations
import math
from .generic_gpu import GenericGpu
from ..ast import PsAstNode
from ..ast.expressions import (
PsExpression,
PsLiteralExpr,
PsCall,
PsAnd,
PsConstantExpr,
PsSymbolExpr,
)
from ..ast.structural import (
PsConditional,
PsStatement,
PsAssignment,
PsBlock,
PsStructuralNode,
)
from ..constants import PsConstant
from ..exceptions import MaterializationError
from ..functions import NumericLimitsFunctions, CFunction
from ..literals import PsLiteral
from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions import ReductionOp
from ...types import PsType, PsIeeeFloatType, PsCustomType, PsPointerType, PsScalarType
from ...types.quick import SInt, UInt
class CudaPlatform(GenericGpu):
@@ -17,7 +38,92 @@ class CudaPlatform(GenericGpu):
'"npp.h"',
}
def resolve_reduction(
self,
ptr_expr: PsExpression,
symbol_expr: PsExpression,
reduction_op: ReductionOp,
) -> tuple[tuple[PsStructuralNode, ...], PsAstNode]:
stype = symbol_expr.dtype
ptrtype = ptr_expr.dtype
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptrtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(stype, PsScalarType)
if not isinstance(stype, PsIeeeFloatType) or stype.width not in (32, 64):
raise NotImplementedError(
"atomic operations are only available for float32/64 datatypes"
)
# workaround for subtractions -> use additions for reducing intermediate results
# similar to OpenMP reductions: local copies (negative sign) are added at the end
match reduction_op:
case ReductionOp.Sub:
actual_reduction_op = ReductionOp.Add
case _:
actual_reduction_op = reduction_op
# check if thread is valid for performing reduction
ispace = self._ctx.get_iteration_space()
is_valid_thread = self._get_condition_for_translation(ispace)
cond: PsExpression
shuffles: tuple[PsAssignment, ...]
if self._warp_size and self._assume_warp_aligned_block_size:
# perform local warp reductions
def gen_shuffle_instr(offset: int):
full_mask = PsLiteralExpr(PsLiteral("0xffffffff", UInt(32)))
return PsCall(
CFunction("__shfl_xor_sync", [UInt(32), stype, SInt(32)], stype),
[
full_mask,
symbol_expr,
PsConstantExpr(PsConstant(offset, SInt(32))),
],
)
# set up shuffle instructions for warp-level reduction
num_shuffles = math.frexp(self._warp_size)[1]
shuffles = tuple(
PsAssignment(
symbol_expr,
compound_op_to_expr(
actual_reduction_op,
symbol_expr,
gen_shuffle_instr(pow(2, i - 1)),
),
)
for i in reversed(range(1, num_shuffles))
)
# find first thread in warp
first_thread_in_warp = self._first_thread_in_warp(ispace)
# set condition to only execute atomic operation on first valid thread in warp
cond = (
PsAnd(is_valid_thread, first_thread_in_warp)
if is_valid_thread
else first_thread_in_warp
)
else:
# no optimization: only execute atomic add on valid thread
shuffles = ()
cond = is_valid_thread
# use atomic operation
func = CFunction(
f"atomic{actual_reduction_op.name}", [ptrtype, stype], PsCustomType("void")
)
func_args = (ptr_expr, symbol_expr)
# assemble warp reduction
return shuffles, PsConditional(
cond, PsBlock([PsStatement(PsCall(func, func_args))])
)
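A side note, not part of the diff: the shuffle schedule above derives the number of butterfly steps from math.frexp, so for the common warp size of 32 the XOR offsets come out as 16, 8, 4, 2, 1. A standalone check of that arithmetic:

import math

warp_size = 32  # assumed; the real code reads self._warp_size
num_shuffles = math.frexp(warp_size)[1]  # frexp(32) == (0.5, 6)
offsets = [2 ** (i - 1) for i in reversed(range(1, num_shuffles))]
print(offsets)  # [16, 8, 4, 2, 1] -> one __shfl_xor_sync per offset reduces a full warp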
def resolve_numeric_limits(
self, func: NumericLimitsFunctions, dtype: PsType
) -> PsExpression:
assert isinstance(dtype, PsIeeeFloatType)
match func:
......
from __future__ import annotations
import math
import operator
from abc import ABC, abstractmethod
from functools import reduce
from ..ast import PsAstNode
from ..constants import PsConstant
from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions.reduction import ReductionOp
from ...types import (
constify,
deconstify,
PsScalarType,
PsType,
)
from ...types.quick import SInt
from ..exceptions import MaterializationError
from .platform import Platform
@@ -28,8 +31,6 @@ from ..ast.structural import (
PsBlock,
PsConditional,
PsDeclaration,
PsStatement,
PsAssignment,
PsStructuralNode,
)
from ..ast.expressions import (
@@ -39,7 +40,6 @@ from ..ast.expressions import (
PsCall,
PsLookup,
PsBufferAcc,
PsSymbolExpr,
PsConstantExpr,
PsAdd,
PsRem,
@@ -206,7 +206,18 @@ class GenericGpu(Platform):
}
@abstractmethod
def resolve_numeric_limits(
self, func: NumericLimitsFunctions, dtype: PsType
) -> PsExpression:
pass
@abstractmethod
def resolve_reduction(
self,
ptr_expr: PsExpression,
symbol_expr: PsExpression,
reduction_op: ReductionOp,
) -> tuple[tuple[PsStructuralNode, ...], PsAstNode]:
pass
def __init__(
@@ -262,6 +273,31 @@ class GenericGpu(Platform):
else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}")
@staticmethod
def _thread_index_per_dim(ispace: IterationSpace) -> tuple[PsExpression, ...]:
"""Returns thread indices multiplied with block dimension strides per dimension."""
return tuple(
idx
* PsConstantExpr(
PsConstant(reduce(operator.mul, BLOCK_DIM[:i], 1), SInt(32))
)
for i, idx in enumerate(THREAD_IDX[: ispace.rank])
)
def _first_thread_in_warp(self, ispace: IterationSpace) -> PsExpression:
"""Returns expression that determines whether a thread is the first within a warp."""
tids_per_dim = GenericGpu._thread_index_per_dim(ispace)
tid: PsExpression = tids_per_dim[0]
for t in tids_per_dim[1:]:
tid = PsAdd(tid, t)
return PsEq(
PsRem(tid, PsConstantExpr(PsConstant(self._warp_size, SInt(32)))),
PsConstantExpr(PsConstant(0, SInt(32))),
)
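For illustration only (not part of the diff): the condition built here linearizes the thread index with the per-dimension strides from _thread_index_per_dim and tests it against the warp size. A small standalone model, assuming the usual x-fastest CUDA layout and a hypothetical 16x16x1 block:

def is_first_in_warp(tx, ty, tz, block_dim=(16, 16, 1), warp_size=32):
    # linear id = tx * 1 + ty * blockDim.x + tz * blockDim.x * blockDim.y
    linear = tx + ty * block_dim[0] + tz * block_dim[0] * block_dim[1]
    return linear % warp_size == 0

assert is_first_in_warp(0, 0, 0)      # lane 0 of the first warp
assert is_first_in_warp(0, 2, 0)      # 2 * 16 == 32 -> lane 0 of the second warp
assert not is_first_in_warp(1, 0, 0)  # lane 1 of the first warp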
def select_function(
self, call: PsCall
) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]:
@@ -276,97 +312,8 @@
):
ptr_expr, symbol_expr = call.args
op = call_func.reduction_op
return self.resolve_reduction(ptr_expr, symbol_expr, op)
dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args
......
from __future__ import annotations
from .generic_gpu import GenericGpu
from ..ast import PsAstNode
from ..ast.expressions import PsExpression, PsLiteralExpr
from ..ast.structural import PsStructuralNode
from ..exceptions import MaterializationError
from ..functions import NumericLimitsFunctions
from ..literals import PsLiteral
from ...sympyextensions import ReductionOp
from ...types import PsType, PsIeeeFloatType
@@ -12,12 +16,11 @@ class HipPlatform(GenericGpu):
@property
def required_headers(self) -> set[str]:
return super().required_headers | {'"pystencils_runtime/hip.h"', "<limits>"}
def resolve_numeric_limits(
self, func: NumericLimitsFunctions, dtype: PsType
) -> PsExpression:
assert isinstance(dtype, PsIeeeFloatType)
return PsLiteralExpr(
@@ -26,3 +29,12 @@ class HipPlatform(GenericGpu):
dtype,
)
)
def resolve_reduction(
self,
ptr_expr: PsExpression,
symbol_expr: PsExpression,
reduction_op: ReductionOp,
) -> tuple[tuple[PsStructuralNode, ...], PsAstNode]:
raise MaterializationError("Reductions are not yet supported in the HIP backend.")