Skip to content
Snippets Groups Projects
Commit d2dd3dfa authored by Frederik Hennig's avatar Frederik Hennig
Browse files

also use ThreadIdxMapping for sparse kernels

parent 3d81f031
No related branches found
No related tags found
1 merge request!449GPU Indexing Schemes and Launch Configurations
...@@ -49,7 +49,7 @@ GRID_DIM = [ ...@@ -49,7 +49,7 @@ GRID_DIM = [
class ThreadToIndexMapping(ABC): class ThreadToIndexMapping(ABC):
@abstractmethod @abstractmethod
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]: def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
"""Map the current thread index onto a point in the given iteration space. """Map the current thread index onto a point in the given iteration space.
Implementations of this method must return a declaration for each dimension counter Implementations of this method must return a declaration for each dimension counter
...@@ -61,7 +61,18 @@ class Linear3DMapping(ThreadToIndexMapping): ...@@ -61,7 +61,18 @@ class Linear3DMapping(ThreadToIndexMapping):
"""3D globally linearized mapping, where each thread is assigned a work item according to """3D globally linearized mapping, where each thread is assigned a work item according to
its location in the global launch grid.""" its location in the global launch grid."""
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]: def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
match ispace:
case FullIterationSpace():
return self._dense_mapping(ispace)
case SparseIterationSpace():
return self._sparse_mapping(ispace)
case _:
assert False, "unexpected iteration space"
def _dense_mapping(
self, ispace: FullIterationSpace
) -> dict[PsSymbol, PsExpression]:
if ispace.rank > 3: if ispace.rank > 3:
raise MaterializationError( raise MaterializationError(
f"Cannot handle {ispace.rank}-dimensional iteration space " f"Cannot handle {ispace.rank}-dimensional iteration space "
...@@ -79,6 +90,18 @@ class Linear3DMapping(ThreadToIndexMapping): ...@@ -79,6 +90,18 @@ class Linear3DMapping(ThreadToIndexMapping):
return idx_map return idx_map
def _sparse_mapping(
self, ispace: SparseIterationSpace
) -> dict[PsSymbol, PsExpression]:
sparse_ctr = PsExpression.make(ispace.sparse_counter)
thread_idx = self._linear_thread_idx(0)
idx_map: dict[PsSymbol, PsExpression] = {
ispace.sparse_counter: PsCast(
deconstify(sparse_ctr.get_dtype()), thread_idx
)
}
return idx_map
def _linear_thread_idx(self, coord: int): def _linear_thread_idx(self, coord: int):
block_size = BLOCK_DIM[coord] block_size = BLOCK_DIM[coord]
block_idx = BLOCK_IDX[coord] block_idx = BLOCK_IDX[coord]
...@@ -97,7 +120,18 @@ class Blockwise4DMapping(ThreadToIndexMapping): ...@@ -97,7 +120,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
THREAD_IDX[0], THREAD_IDX[0],
] ]
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]: def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
match ispace:
case FullIterationSpace():
return self._dense_mapping(ispace)
case SparseIterationSpace():
return self._sparse_mapping(ispace)
case _:
assert False, "unexpected iteration space"
def _dense_mapping(
self, ispace: FullIterationSpace
) -> dict[PsSymbol, PsExpression]:
if ispace.rank > 4: if ispace.rank > 4:
raise MaterializationError( raise MaterializationError(
f"Cannot handle {ispace.rank}-dimensional iteration space " f"Cannot handle {ispace.rank}-dimensional iteration space "
...@@ -114,6 +148,18 @@ class Blockwise4DMapping(ThreadToIndexMapping): ...@@ -114,6 +148,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
return idx_map return idx_map
def _sparse_mapping(
self, ispace: SparseIterationSpace
) -> dict[PsSymbol, PsExpression]:
sparse_ctr = PsExpression.make(ispace.sparse_counter)
thread_idx = self._indices_in_loop_order[-1]
idx_map: dict[PsSymbol, PsExpression] = {
ispace.sparse_counter: PsCast(
deconstify(sparse_ctr.get_dtype()), thread_idx
)
}
return idx_map
class CudaPlatform(GenericGpu): class CudaPlatform(GenericGpu):
"""Platform for CUDA-based GPUs.""" """Platform for CUDA-based GPUs."""
...@@ -127,7 +173,9 @@ class CudaPlatform(GenericGpu): ...@@ -127,7 +173,9 @@ class CudaPlatform(GenericGpu):
super().__init__(ctx) super().__init__(ctx)
self._omit_range_check = omit_range_check self._omit_range_check = omit_range_check
self._thread_mapping = thread_mapping self._thread_mapping = (
thread_mapping if thread_mapping is not None else Linear3DMapping()
)
self._typify = Typifier(ctx) self._typify = Typifier(ctx)
...@@ -212,26 +260,7 @@ class CudaPlatform(GenericGpu): ...@@ -212,26 +260,7 @@ class CudaPlatform(GenericGpu):
) -> PsBlock: ) -> PsBlock:
dimensions = ispace.dimensions_in_loop_order() dimensions = ispace.dimensions_in_loop_order()
# TODO move to codegen ctr_mapping = self._thread_mapping(ispace)
# if not self._manual_launch_grid:
# try:
# threads_range = self.threads_from_ispace(ispace)
# except MaterializationError as e:
# warn(
# str(e.args[0])
# + "\nIf this is intended, set `manual_launch_grid=True` in the code generator configuration.",
# UserWarning,
# )
# threads_range = None
# else:
# threads_range = None
idx_mapper = (
self._thread_mapping
if self._thread_mapping is not None
else Linear3DMapping()
)
ctr_mapping = idx_mapper(ispace)
indexing_decls = [] indexing_decls = []
conds = [] conds = []
...@@ -264,10 +293,11 @@ class CudaPlatform(GenericGpu): ...@@ -264,10 +293,11 @@ class CudaPlatform(GenericGpu):
factory = AstFactory(self._ctx) factory = AstFactory(self._ctx)
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
sparse_ctr = PsExpression.make(ispace.sparse_counter) sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
thread_idx = BLOCK_IDX[0] * BLOCK_DIM[0] + THREAD_IDX[0] ctr_mapping = self._thread_mapping(ispace)
sparse_idx_decl = self._typify( sparse_idx_decl = self._typify(
PsDeclaration(sparse_ctr, PsCast(sparse_ctr.get_dtype(), thread_idx)) PsDeclaration(sparse_ctr_expr, ctr_mapping[ispace.sparse_counter])
) )
mappings = [ mappings = [
...@@ -276,7 +306,7 @@ class CudaPlatform(GenericGpu): ...@@ -276,7 +306,7 @@ class CudaPlatform(GenericGpu):
PsLookup( PsLookup(
PsBufferAcc( PsBufferAcc(
ispace.index_list.base_pointer, ispace.index_list.base_pointer,
(sparse_ctr, factory.parse_index(0)), (sparse_ctr_expr.clone(), factory.parse_index(0)),
), ),
coord.name, coord.name,
), ),
...@@ -287,7 +317,7 @@ class CudaPlatform(GenericGpu): ...@@ -287,7 +317,7 @@ class CudaPlatform(GenericGpu):
if not self._omit_range_check: if not self._omit_range_check:
stop = PsExpression.make(ispace.index_list.shape[0]) stop = PsExpression.make(ispace.index_list.shape[0])
condition = PsLt(sparse_ctr, stop) condition = PsLt(sparse_ctr_expr.clone(), stop)
ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)]) ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
else: else:
body.statements = [sparse_idx_decl] + body.statements body.statements = [sparse_idx_decl] + body.statements
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment