-
Frederik Hennig authoredFrederik Hennig authored
cuda.py 11.30 KiB
from __future__ import annotations
from abc import ABC, abstractmethod
from ...types import constify, deconstify
from ..exceptions import MaterializationError
from .generic_gpu import GenericGpu
from ..memory import PsSymbol
from ..kernelcreation import (
Typifier,
IterationSpace,
FullIterationSpace,
SparseIterationSpace,
AstFactory,
)
from ..kernelcreation.context import KernelCreationContext
from ..ast.structural import PsBlock, PsConditional, PsDeclaration
from ..ast.expressions import (
PsExpression,
PsLiteralExpr,
PsCast,
PsCall,
PsLookup,
PsBufferAcc,
)
from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType, PsIeeeFloatType
from ..literals import PsLiteral
from ..functions import PsMathFunction, MathFunctions, CFunction
int32 = PsSignedIntegerType(width=32, const=False)
BLOCK_IDX = [
PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z")
]
THREAD_IDX = [
PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z")
]
BLOCK_DIM = [
PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z")
]
GRID_DIM = [
PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z")
]
class ThreadMapping(ABC):
@abstractmethod
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
"""Map the current thread index onto a point in the given iteration space.
Implementations of this method must return a declaration for each dimension counter
of the given iteration space.
"""
class Linear3DMapping(ThreadMapping):
"""3D globally linearized mapping, where each thread is assigned a work item according to
its location in the global launch grid."""
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
match ispace:
case FullIterationSpace():
return self._dense_mapping(ispace)
case SparseIterationSpace():
return self._sparse_mapping(ispace)
case _:
assert False, "unexpected iteration space"
def _dense_mapping(
self, ispace: FullIterationSpace
) -> dict[PsSymbol, PsExpression]:
if ispace.rank > 3:
raise MaterializationError(
f"Cannot handle {ispace.rank}-dimensional iteration space "
"using the Linear3D GPU thread index mapping."
)
dimensions = ispace.dimensions_in_loop_order()
idx_map: dict[PsSymbol, PsExpression] = dict()
for coord, dim in enumerate(dimensions[::-1]):
tid = self._linear_thread_idx(coord)
idx_map[dim.counter] = dim.start + dim.step * PsCast(
deconstify(dim.counter.get_dtype()), tid
)
return idx_map
def _sparse_mapping(
self, ispace: SparseIterationSpace
) -> dict[PsSymbol, PsExpression]:
sparse_ctr = PsExpression.make(ispace.sparse_counter)
thread_idx = self._linear_thread_idx(0)
idx_map: dict[PsSymbol, PsExpression] = {
ispace.sparse_counter: PsCast(
deconstify(sparse_ctr.get_dtype()), thread_idx
)
}
return idx_map
def _linear_thread_idx(self, coord: int):
block_size = BLOCK_DIM[coord]
block_idx = BLOCK_IDX[coord]
thread_idx = THREAD_IDX[coord]
return block_idx * block_size + thread_idx
class Blockwise4DMapping(ThreadMapping):
"""Blockwise index mapping for up to 4D iteration spaces, where the outer three dimensions
are mapped to block indices."""
_indices_fastest_first = [ # slowest to fastest
THREAD_IDX[0],
BLOCK_IDX[0],
BLOCK_IDX[1],
BLOCK_IDX[2]
]
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
match ispace:
case FullIterationSpace():
return self._dense_mapping(ispace)
case SparseIterationSpace():
return self._sparse_mapping(ispace)
case _:
assert False, "unexpected iteration space"
def _dense_mapping(
self, ispace: FullIterationSpace
) -> dict[PsSymbol, PsExpression]:
if ispace.rank > 4:
raise MaterializationError(
f"Cannot handle {ispace.rank}-dimensional iteration space "
"using the Blockwise4D GPU thread index mapping."
)
dimensions = ispace.dimensions_in_loop_order()
idx_map: dict[PsSymbol, PsExpression] = dict()
for dim, tid in zip(dimensions[::-1], self._indices_fastest_first):
idx_map[dim.counter] = dim.start + dim.step * PsCast(
deconstify(dim.counter.get_dtype()), tid
)
return idx_map
def _sparse_mapping(
self, ispace: SparseIterationSpace
) -> dict[PsSymbol, PsExpression]:
sparse_ctr = PsExpression.make(ispace.sparse_counter)
thread_idx = self._indices_fastest_first[0]
idx_map: dict[PsSymbol, PsExpression] = {
ispace.sparse_counter: PsCast(
deconstify(sparse_ctr.get_dtype()), thread_idx
)
}
return idx_map
class CudaPlatform(GenericGpu):
"""Platform for CUDA-based GPUs.
Args:
ctx: The kernel creation context
omit_range_check: If `True`, generated index translation code will not check if the point identified
by block and thread indices is actually contained in the iteration space
thread_mapping: Callback object which defines the mapping of thread indices onto iteration space points
"""
def __init__(
self,
ctx: KernelCreationContext,
omit_range_check: bool = False,
thread_mapping: ThreadMapping | None = None,
) -> None:
super().__init__(ctx)
self._omit_range_check = omit_range_check
self._thread_mapping = (
thread_mapping if thread_mapping is not None else Linear3DMapping()
)
self._typify = Typifier(ctx)
@property
def required_headers(self) -> set[str]:
return {'"gpu_defines.h"'}
def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace
) -> PsBlock:
if isinstance(ispace, FullIterationSpace):
return self._prepend_dense_translation(body, ispace)
elif isinstance(ispace, SparseIterationSpace):
return self._prepend_sparse_translation(body, ispace)
else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def select_function(self, call: PsCall) -> PsExpression:
assert isinstance(call.function, PsMathFunction)
func = call.function.func
dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args
if isinstance(dtype, PsIeeeFloatType):
match func:
case (
MathFunctions.Exp
| MathFunctions.Log
| MathFunctions.Sin
| MathFunctions.Cos
| MathFunctions.Sqrt
| MathFunctions.Ceil
| MathFunctions.Floor
) if dtype.width in (16, 32, 64):
prefix = "h" if dtype.width == 16 else ""
suffix = "f" if dtype.width == 32 else ""
name = f"{prefix}{func.function_name}{suffix}"
cfunc = CFunction(name, arg_types, dtype)
case (
MathFunctions.Pow
| MathFunctions.Tan
| MathFunctions.Sinh
| MathFunctions.Cosh
| MathFunctions.ASin
| MathFunctions.ACos
| MathFunctions.ATan
| MathFunctions.ATan2
) if dtype.width in (32, 64):
# These are unavailable for fp16
suffix = "f" if dtype.width == 32 else ""
name = f"{func.function_name}{suffix}"
cfunc = CFunction(name, arg_types, dtype)
case (
MathFunctions.Min | MathFunctions.Max | MathFunctions.Abs
) if dtype.width in (32, 64):
suffix = "f" if dtype.width == 32 else ""
name = f"f{func.function_name}{suffix}"
cfunc = CFunction(name, arg_types, dtype)
case MathFunctions.Abs if dtype.width == 16:
cfunc = CFunction(" __habs", arg_types, dtype)
case _:
raise MaterializationError(
f"Cannot materialize call to function {func}"
)
call.function = cfunc
return call
raise MaterializationError(
f"No implementation available for function {func} on data type {dtype}"
)
# Internals
def _prepend_dense_translation(
self, body: PsBlock, ispace: FullIterationSpace
) -> PsBlock:
ctr_mapping = self._thread_mapping(ispace)
indexing_decls = []
conds = []
dimensions = ispace.dimensions_in_loop_order()
for dim in dimensions:
# counter declarations must be ordered slowest-to-fastest
# such that inner dimensions can depend on outer ones
dim.counter.dtype = constify(dim.counter.get_dtype())
ctr_expr = PsExpression.make(dim.counter)
indexing_decls.append(
self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter]))
)
if not self._omit_range_check:
conds.append(PsLt(ctr_expr, dim.stop))
if conds:
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
else:
body.statements = indexing_decls + body.statements
ast = body
return ast
def _prepend_sparse_translation(
self, body: PsBlock, ispace: SparseIterationSpace
) -> PsBlock:
factory = AstFactory(self._ctx)
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
ctr_mapping = self._thread_mapping(ispace)
sparse_idx_decl = self._typify(
PsDeclaration(sparse_ctr_expr, ctr_mapping[ispace.sparse_counter])
)
mappings = [
PsDeclaration(
PsExpression.make(ctr),
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
(sparse_ctr_expr.clone(), factory.parse_index(0)),
),
coord.name,
),
)
for ctr, coord in zip(ispace.spatial_indices, ispace.coordinate_members)
]
body.statements = mappings + body.statements
if not self._omit_range_check:
stop = PsExpression.make(ispace.index_list.shape[0])
condition = PsLt(sparse_ctr_expr.clone(), stop)
ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
else:
body.statements = [sparse_idx_decl] + body.statements
ast = body
return ast