Skip to content
Snippets Groups Projects

GPU Indexing Schemes and Launch Configurations

Merged Frederik Hennig requested to merge fhennig/lambdas into v2.0-dev
1 file
+ 59
29
Compare changes
  • Side-by-side
  • Inline
@@ -49,7 +49,7 @@ GRID_DIM = [
@@ -49,7 +49,7 @@ GRID_DIM = [
class ThreadToIndexMapping(ABC):
class ThreadToIndexMapping(ABC):
@abstractmethod
@abstractmethod
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
"""Map the current thread index onto a point in the given iteration space.
"""Map the current thread index onto a point in the given iteration space.
Implementations of this method must return a declaration for each dimension counter
Implementations of this method must return a declaration for each dimension counter
@@ -61,7 +61,18 @@ class Linear3DMapping(ThreadToIndexMapping):
@@ -61,7 +61,18 @@ class Linear3DMapping(ThreadToIndexMapping):
"""3D globally linearized mapping, where each thread is assigned a work item according to
"""3D globally linearized mapping, where each thread is assigned a work item according to
its location in the global launch grid."""
its location in the global launch grid."""
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
    """Build the counter mapping for the given iteration space.

    The work is delegated according to the concrete kind of `ispace`:
    a `FullIterationSpace` goes to `_dense_mapping`, a
    `SparseIterationSpace` goes to `_sparse_mapping`. Any other argument
    trips the assertion at the end of this method.
    """
    if isinstance(ispace, FullIterationSpace):
        return self._dense_mapping(ispace)

    if isinstance(ispace, SparseIterationSpace):
        return self._sparse_mapping(ispace)

    # Neither of the two handled iteration-space kinds matched.
    assert False, "unexpected iteration space"
 
 
def _dense_mapping(
 
self, ispace: FullIterationSpace
 
) -> dict[PsSymbol, PsExpression]:
if ispace.rank > 3:
if ispace.rank > 3:
raise MaterializationError(
raise MaterializationError(
f"Cannot handle {ispace.rank}-dimensional iteration space "
f"Cannot handle {ispace.rank}-dimensional iteration space "
@@ -79,6 +90,18 @@ class Linear3DMapping(ThreadToIndexMapping):
@@ -79,6 +90,18 @@ class Linear3DMapping(ThreadToIndexMapping):
return idx_map
return idx_map
 
def _sparse_mapping(
    self, ispace: SparseIterationSpace
) -> dict[PsSymbol, PsExpression]:
    """Map the sparse counter of `ispace` onto the linear thread index.

    Returns a single-entry dictionary whose key is `ispace.sparse_counter`
    and whose value is the result of `self._linear_thread_idx(0)`, cast to
    the counter's data type after that type was passed through `deconstify`.
    """
    counter = ispace.sparse_counter
    counter_expr = PsExpression.make(counter)
    linear_idx = self._linear_thread_idx(0)

    # The cast target is derived from the counter expression's own dtype.
    target_type = deconstify(counter_expr.get_dtype())
    return {counter: PsCast(target_type, linear_idx)}
 
def _linear_thread_idx(self, coord: int):
def _linear_thread_idx(self, coord: int):
block_size = BLOCK_DIM[coord]
block_size = BLOCK_DIM[coord]
block_idx = BLOCK_IDX[coord]
block_idx = BLOCK_IDX[coord]
@@ -97,7 +120,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
@@ -97,7 +120,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
THREAD_IDX[0],
THREAD_IDX[0],
]
]
def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
    """Build the counter mapping for the given iteration space.

    The work is delegated according to the concrete kind of `ispace`:
    a `FullIterationSpace` goes to `_dense_mapping`, a
    `SparseIterationSpace` goes to `_sparse_mapping`. Any other argument
    trips the assertion at the end of this method.
    """
    if isinstance(ispace, FullIterationSpace):
        return self._dense_mapping(ispace)

    if isinstance(ispace, SparseIterationSpace):
        return self._sparse_mapping(ispace)

    # Neither of the two handled iteration-space kinds matched.
    assert False, "unexpected iteration space"
 
 
def _dense_mapping(
 
self, ispace: FullIterationSpace
 
) -> dict[PsSymbol, PsExpression]:
if ispace.rank > 4:
if ispace.rank > 4:
raise MaterializationError(
raise MaterializationError(
f"Cannot handle {ispace.rank}-dimensional iteration space "
f"Cannot handle {ispace.rank}-dimensional iteration space "
@@ -114,6 +148,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
@@ -114,6 +148,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
return idx_map
return idx_map
 
def _sparse_mapping(
    self, ispace: SparseIterationSpace
) -> dict[PsSymbol, PsExpression]:
    """Map the sparse counter of `ispace` onto a single thread index.

    Returns a single-entry dictionary whose key is `ispace.sparse_counter`
    and whose value is the last entry of `self._indices_in_loop_order`,
    cast to the counter's data type after that type was passed through
    `deconstify`.
    """
    counter = ispace.sparse_counter
    counter_expr = PsExpression.make(counter)
    last_idx = self._indices_in_loop_order[-1]

    # The cast target is derived from the counter expression's own dtype.
    target_type = deconstify(counter_expr.get_dtype())
    return {counter: PsCast(target_type, last_idx)}
 
class CudaPlatform(GenericGpu):
class CudaPlatform(GenericGpu):
"""Platform for CUDA-based GPUs."""
"""Platform for CUDA-based GPUs."""
@@ -127,7 +173,9 @@ class CudaPlatform(GenericGpu):
@@ -127,7 +173,9 @@ class CudaPlatform(GenericGpu):
super().__init__(ctx)
super().__init__(ctx)
self._omit_range_check = omit_range_check
self._omit_range_check = omit_range_check
self._thread_mapping = thread_mapping
self._thread_mapping = (
 
thread_mapping if thread_mapping is not None else Linear3DMapping()
 
)
self._typify = Typifier(ctx)
self._typify = Typifier(ctx)
@@ -212,26 +260,7 @@ class CudaPlatform(GenericGpu):
@@ -212,26 +260,7 @@ class CudaPlatform(GenericGpu):
) -> PsBlock:
) -> PsBlock:
dimensions = ispace.dimensions_in_loop_order()
dimensions = ispace.dimensions_in_loop_order()
# TODO move to codegen
ctr_mapping = self._thread_mapping(ispace)
# if not self._manual_launch_grid:
# try:
# threads_range = self.threads_from_ispace(ispace)
# except MaterializationError as e:
# warn(
# str(e.args[0])
# + "\nIf this is intended, set `manual_launch_grid=True` in the code generator configuration.",
# UserWarning,
# )
# threads_range = None
# else:
# threads_range = None
idx_mapper = (
self._thread_mapping
if self._thread_mapping is not None
else Linear3DMapping()
)
ctr_mapping = idx_mapper(ispace)
indexing_decls = []
indexing_decls = []
conds = []
conds = []
@@ -264,10 +293,11 @@ class CudaPlatform(GenericGpu):
@@ -264,10 +293,11 @@ class CudaPlatform(GenericGpu):
factory = AstFactory(self._ctx)
factory = AstFactory(self._ctx)
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
sparse_ctr = PsExpression.make(ispace.sparse_counter)
sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
thread_idx = BLOCK_IDX[0] * BLOCK_DIM[0] + THREAD_IDX[0]
ctr_mapping = self._thread_mapping(ispace)
 
sparse_idx_decl = self._typify(
sparse_idx_decl = self._typify(
PsDeclaration(sparse_ctr, PsCast(sparse_ctr.get_dtype(), thread_idx))
PsDeclaration(sparse_ctr_expr, ctr_mapping[ispace.sparse_counter])
)
)
mappings = [
mappings = [
@@ -276,7 +306,7 @@ class CudaPlatform(GenericGpu):
@@ -276,7 +306,7 @@ class CudaPlatform(GenericGpu):
PsLookup(
PsLookup(
PsBufferAcc(
PsBufferAcc(
ispace.index_list.base_pointer,
ispace.index_list.base_pointer,
(sparse_ctr, factory.parse_index(0)),
(sparse_ctr_expr.clone(), factory.parse_index(0)),
),
),
coord.name,
coord.name,
),
),
@@ -287,7 +317,7 @@ class CudaPlatform(GenericGpu):
@@ -287,7 +317,7 @@ class CudaPlatform(GenericGpu):
if not self._omit_range_check:
if not self._omit_range_check:
stop = PsExpression.make(ispace.index_list.shape[0])
stop = PsExpression.make(ispace.index_list.shape[0])
condition = PsLt(sparse_ctr, stop)
condition = PsLt(sparse_ctr_expr.clone(), stop)
ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
else:
else:
body.statements = [sparse_idx_decl] + body.statements
body.statements = [sparse_idx_decl] + body.statements
Loading