Skip to content
Snippets Groups Projects

Optimization for GPU block size determination

Compare and
6 files
+ 148
- 58
Compare changes
  • Side-by-side
  • Inline

Files

@@ -166,20 +166,16 @@ class CudaPlatform(GenericGpu):
Args:
ctx: The kernel creation context
omit_range_check: If `True`, generated index translation code will not check if the point identified
by block and thread indices is actually contained in the iteration space
thread_mapping: Callback object which defines the mapping of thread indices onto iteration space points
"""
def __init__(
self,
ctx: KernelCreationContext,
omit_range_check: bool = False,
thread_mapping: ThreadMapping | None = None,
) -> None:
super().__init__(ctx)
self._omit_range_check = omit_range_check
self._thread_mapping = (
thread_mapping if thread_mapping is not None else Linear3DMapping()
)
@@ -282,19 +278,12 @@ class CudaPlatform(GenericGpu):
indexing_decls.append(
self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter]))
)
if not self._omit_range_check:
conds.append(PsLt(ctr_expr, dim.stop))
if conds:
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
else:
body.statements = indexing_decls + body.statements
ast = body
conds.append(PsLt(ctr_expr, dim.stop))
return ast
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
return PsBlock(indexing_decls + [PsConditional(condition, body)])
def _prepend_sparse_translation(
self, body: PsBlock, ispace: SparseIterationSpace
@@ -324,12 +313,6 @@ class CudaPlatform(GenericGpu):
]
body.statements = mappings + body.statements
if not self._omit_range_check:
stop = PsExpression.make(ispace.index_list.shape[0])
condition = PsLt(sparse_ctr_expr.clone(), stop)
ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
else:
body.statements = [sparse_idx_decl] + body.statements
ast = body
return ast
stop = PsExpression.make(ispace.index_list.shape[0])
condition = PsLt(sparse_ctr_expr.clone(), stop)
return PsBlock([sparse_idx_decl, PsConditional(condition, body)])
Loading