Skip to content
Snippets Groups Projects
Commit d2dd3dfa authored by Frederik Hennig
Browse files

also use ThreadIdxMapping for sparse kernels

parent 3d81f031
Branches
Tags release/0.3.2
1 merge request: !449 "GPU Indexing Schemes and Launch Configurations"
......@@ -49,7 +49,7 @@ GRID_DIM = [
class ThreadToIndexMapping(ABC):
@abstractmethod
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
"""Map the current thread index onto a point in the given iteration space.
Implementations of this method must return a declaration for each dimension counter
......@@ -61,7 +61,18 @@ class Linear3DMapping(ThreadToIndexMapping):
"""3D globally linearized mapping, where each thread is assigned a work item according to
its location in the global launch grid."""
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
    """Map the current thread index onto a point in the given iteration space.

    Dispatches on the concrete kind of ``ispace``: a `FullIterationSpace`
    is forwarded to `_dense_mapping`, a `SparseIterationSpace` to
    `_sparse_mapping`, and the result of the chosen helper is returned
    unchanged.

    Raises:
        AssertionError: if ``ispace`` is neither of the two handled kinds
            (only while assertions are enabled).
    """
    if isinstance(ispace, FullIterationSpace):
        return self._dense_mapping(ispace)

    if isinstance(ispace, SparseIterationSpace):
        return self._sparse_mapping(ispace)

    # No other iteration space kind is handled by this mapping.
    assert False, "unexpected iteration space"
def _dense_mapping(
self, ispace: FullIterationSpace
) -> dict[PsSymbol, PsExpression]:
if ispace.rank > 3:
raise MaterializationError(
f"Cannot handle {ispace.rank}-dimensional iteration space "
......@@ -79,6 +90,18 @@ class Linear3DMapping(ThreadToIndexMapping):
return idx_map
def _sparse_mapping(
    self, ispace: SparseIterationSpace
) -> dict[PsSymbol, PsExpression]:
    """Build the counter mapping for a sparse iteration space.

    The returned dictionary has a single entry: the sparse counter symbol
    of ``ispace``, mapped to the linear thread index along coordinate 0
    (as produced by `_linear_thread_idx`), cast to the counter's data type
    after passing that type through `deconstify`.
    """
    counter_symbol = ispace.sparse_counter
    counter_expr = PsExpression.make(counter_symbol)
    linear_idx = self._linear_thread_idx(0)

    # The cast target is the counter's own dtype, run through `deconstify`.
    target_dtype = deconstify(counter_expr.get_dtype())

    return {counter_symbol: PsCast(target_dtype, linear_idx)}
def _linear_thread_idx(self, coord: int):
block_size = BLOCK_DIM[coord]
block_idx = BLOCK_IDX[coord]
......@@ -97,7 +120,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
THREAD_IDX[0],
]
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
    """Map the current thread index onto a point in the given iteration space.

    Selects the mapping routine by the concrete kind of ``ispace``:
    `_dense_mapping` for a `FullIterationSpace`, `_sparse_mapping` for a
    `SparseIterationSpace`. The helper's result is returned as-is.

    Raises:
        AssertionError: if ``ispace`` is of any other kind (only while
            assertions are enabled).
    """
    if isinstance(ispace, FullIterationSpace):
        return self._dense_mapping(ispace)

    if isinstance(ispace, SparseIterationSpace):
        return self._sparse_mapping(ispace)

    # Any other kind of iteration space is not handled here.
    assert False, "unexpected iteration space"
def _dense_mapping(
self, ispace: FullIterationSpace
) -> dict[PsSymbol, PsExpression]:
if ispace.rank > 4:
raise MaterializationError(
f"Cannot handle {ispace.rank}-dimensional iteration space "
......@@ -114,6 +148,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
return idx_map
def _sparse_mapping(
    self, ispace: SparseIterationSpace
) -> dict[PsSymbol, PsExpression]:
    """Build the counter mapping for a sparse iteration space.

    The returned dictionary has a single entry: the sparse counter symbol
    of ``ispace``, mapped to the last entry of
    ``self._indices_in_loop_order``, cast to the counter's data type after
    passing that type through `deconstify`.
    """
    counter_symbol = ispace.sparse_counter
    counter_expr = PsExpression.make(counter_symbol)

    # NOTE(review): only the last index of the loop-order list is used, with
    # no block offset added here -- confirm this is the intended sparse index.
    innermost_idx = self._indices_in_loop_order[-1]

    target_dtype = deconstify(counter_expr.get_dtype())

    return {counter_symbol: PsCast(target_dtype, innermost_idx)}
class CudaPlatform(GenericGpu):
"""Platform for CUDA-based GPUs."""
......@@ -127,7 +173,9 @@ class CudaPlatform(GenericGpu):
super().__init__(ctx)
self._omit_range_check = omit_range_check
self._thread_mapping = thread_mapping
self._thread_mapping = (
thread_mapping if thread_mapping is not None else Linear3DMapping()
)
self._typify = Typifier(ctx)
......@@ -212,26 +260,7 @@ class CudaPlatform(GenericGpu):
) -> PsBlock:
dimensions = ispace.dimensions_in_loop_order()
# TODO move to codegen
# if not self._manual_launch_grid:
# try:
# threads_range = self.threads_from_ispace(ispace)
# except MaterializationError as e:
# warn(
# str(e.args[0])
# + "\nIf this is intended, set `manual_launch_grid=True` in the code generator configuration.",
# UserWarning,
# )
# threads_range = None
# else:
# threads_range = None
idx_mapper = (
self._thread_mapping
if self._thread_mapping is not None
else Linear3DMapping()
)
ctr_mapping = idx_mapper(ispace)
ctr_mapping = self._thread_mapping(ispace)
indexing_decls = []
conds = []
......@@ -264,10 +293,11 @@ class CudaPlatform(GenericGpu):
factory = AstFactory(self._ctx)
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
sparse_ctr = PsExpression.make(ispace.sparse_counter)
thread_idx = BLOCK_IDX[0] * BLOCK_DIM[0] + THREAD_IDX[0]
sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
ctr_mapping = self._thread_mapping(ispace)
sparse_idx_decl = self._typify(
PsDeclaration(sparse_ctr, PsCast(sparse_ctr.get_dtype(), thread_idx))
PsDeclaration(sparse_ctr_expr, ctr_mapping[ispace.sparse_counter])
)
mappings = [
......@@ -276,7 +306,7 @@ class CudaPlatform(GenericGpu):
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
(sparse_ctr, factory.parse_index(0)),
(sparse_ctr_expr.clone(), factory.parse_index(0)),
),
coord.name,
),
......@@ -287,7 +317,7 @@ class CudaPlatform(GenericGpu):
if not self._omit_range_check:
stop = PsExpression.make(ispace.index_list.shape[0])
condition = PsLt(sparse_ctr, stop)
condition = PsLt(sparse_ctr_expr.clone(), stop)
ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
else:
body.statements = [sparse_idx_decl] + body.statements
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment