Skip to content
Snippets Groups Projects
Commit d2dd3dfa authored by Frederik Hennig's avatar Frederik Hennig
Browse files

also use ThreadIdxMapping for sparse kernels

parent 3d81f031
No related branches found
No related tags found
1 merge request!449GPU Indexing Schemes and Launch Configurations
......@@ -49,7 +49,7 @@ GRID_DIM = [
class ThreadToIndexMapping(ABC):
@abstractmethod
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
"""Map the current thread index onto a point in the given iteration space.
Implementations of this method must return a declaration for each dimension counter
......@@ -61,7 +61,18 @@ class Linear3DMapping(ThreadToIndexMapping):
"""3D globally linearized mapping, where each thread is assigned a work item according to
its location in the global launch grid."""
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
match ispace:
case FullIterationSpace():
return self._dense_mapping(ispace)
case SparseIterationSpace():
return self._sparse_mapping(ispace)
case _:
assert False, "unexpected iteration space"
def _dense_mapping(
self, ispace: FullIterationSpace
) -> dict[PsSymbol, PsExpression]:
if ispace.rank > 3:
raise MaterializationError(
f"Cannot handle {ispace.rank}-dimensional iteration space "
......@@ -79,6 +90,18 @@ class Linear3DMapping(ThreadToIndexMapping):
return idx_map
def _sparse_mapping(
self, ispace: SparseIterationSpace
) -> dict[PsSymbol, PsExpression]:
sparse_ctr = PsExpression.make(ispace.sparse_counter)
thread_idx = self._linear_thread_idx(0)
idx_map: dict[PsSymbol, PsExpression] = {
ispace.sparse_counter: PsCast(
deconstify(sparse_ctr.get_dtype()), thread_idx
)
}
return idx_map
def _linear_thread_idx(self, coord: int):
block_size = BLOCK_DIM[coord]
block_idx = BLOCK_IDX[coord]
......@@ -97,7 +120,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
THREAD_IDX[0],
]
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
match ispace:
case FullIterationSpace():
return self._dense_mapping(ispace)
case SparseIterationSpace():
return self._sparse_mapping(ispace)
case _:
assert False, "unexpected iteration space"
def _dense_mapping(
self, ispace: FullIterationSpace
) -> dict[PsSymbol, PsExpression]:
if ispace.rank > 4:
raise MaterializationError(
f"Cannot handle {ispace.rank}-dimensional iteration space "
......@@ -114,6 +148,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
return idx_map
def _sparse_mapping(
self, ispace: SparseIterationSpace
) -> dict[PsSymbol, PsExpression]:
sparse_ctr = PsExpression.make(ispace.sparse_counter)
thread_idx = self._indices_in_loop_order[-1]
idx_map: dict[PsSymbol, PsExpression] = {
ispace.sparse_counter: PsCast(
deconstify(sparse_ctr.get_dtype()), thread_idx
)
}
return idx_map
class CudaPlatform(GenericGpu):
"""Platform for CUDA-based GPUs."""
......@@ -127,7 +173,9 @@ class CudaPlatform(GenericGpu):
super().__init__(ctx)
self._omit_range_check = omit_range_check
self._thread_mapping = thread_mapping
self._thread_mapping = (
thread_mapping if thread_mapping is not None else Linear3DMapping()
)
self._typify = Typifier(ctx)
......@@ -212,26 +260,7 @@ class CudaPlatform(GenericGpu):
) -> PsBlock:
dimensions = ispace.dimensions_in_loop_order()
# TODO move to codegen
# if not self._manual_launch_grid:
# try:
# threads_range = self.threads_from_ispace(ispace)
# except MaterializationError as e:
# warn(
# str(e.args[0])
# + "\nIf this is intended, set `manual_launch_grid=True` in the code generator configuration.",
# UserWarning,
# )
# threads_range = None
# else:
# threads_range = None
idx_mapper = (
self._thread_mapping
if self._thread_mapping is not None
else Linear3DMapping()
)
ctr_mapping = idx_mapper(ispace)
ctr_mapping = self._thread_mapping(ispace)
indexing_decls = []
conds = []
......@@ -264,10 +293,11 @@ class CudaPlatform(GenericGpu):
factory = AstFactory(self._ctx)
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
sparse_ctr = PsExpression.make(ispace.sparse_counter)
thread_idx = BLOCK_IDX[0] * BLOCK_DIM[0] + THREAD_IDX[0]
sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
ctr_mapping = self._thread_mapping(ispace)
sparse_idx_decl = self._typify(
PsDeclaration(sparse_ctr, PsCast(sparse_ctr.get_dtype(), thread_idx))
PsDeclaration(sparse_ctr_expr, ctr_mapping[ispace.sparse_counter])
)
mappings = [
......@@ -276,7 +306,7 @@ class CudaPlatform(GenericGpu):
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
(sparse_ctr, factory.parse_index(0)),
(sparse_ctr_expr.clone(), factory.parse_index(0)),
),
coord.name,
),
......@@ -287,7 +317,7 @@ class CudaPlatform(GenericGpu):
if not self._omit_range_check:
stop = PsExpression.make(ispace.index_list.shape[0])
condition = PsLt(sparse_ctr, stop)
condition = PsLt(sparse_ctr_expr.clone(), stop)
ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
else:
body.statements = [sparse_idx_decl] + body.statements
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment