Skip to content
Snippets Groups Projects
Commit ce816539 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Encapsulate fetching of kernel conditions for iteration spaces in separate function

parent d10c65d4
No related branches found
No related tags found
1 merge request!438Reduction Support
......@@ -315,13 +315,47 @@ class CudaPlatform(GenericGpu):
# Internals
# TODO: SYCL platform has very similar code for fetching conditionals -> move to GenericGPU?
def _get_condition_for_translation(
self, ispace: IterationSpace):
if not self._omit_range_check:
return None
match ispace:
case FullIterationSpace():
dimensions = ispace.dimensions_in_loop_order()
conds = []
for dim in dimensions:
ctr_expr = PsExpression.make(dim.counter)
conds.append(PsLt(ctr_expr, dim.stop))
if conds:
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
return condition
else:
return None
case SparseIterationSpace():
sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
stop = PsExpression.make(ispace.index_list.shape[0])
return PsLt(sparse_ctr_expr.clone(), stop)
case _:
assert False, "Unknown iteration space"
def _prepend_dense_translation(
self, body: PsBlock, ispace: FullIterationSpace
) -> PsBlock:
ctr_mapping = self._thread_mapping(ispace)
indexing_decls = []
conds = []
cond = self._get_condition_for_translation(ispace)
dimensions = ispace.dimensions_in_loop_order()
......@@ -335,14 +369,9 @@ class CudaPlatform(GenericGpu):
indexing_decls.append(
self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter]))
)
if not self._omit_range_check:
conds.append(PsLt(ctr_expr, dim.stop))
if conds:
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
ast = PsBlock(indexing_decls + [PsConditional(condition, body)])
if cond:
ast = PsBlock(indexing_decls + [PsConditional(cond, body)])
else:
body.statements = indexing_decls + body.statements
ast = body
......@@ -355,6 +384,8 @@ class CudaPlatform(GenericGpu):
factory = AstFactory(self._ctx)
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
cond = self._get_condition_for_translation(ispace)
sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
ctr_mapping = self._thread_mapping(ispace)
......@@ -377,10 +408,8 @@ class CudaPlatform(GenericGpu):
]
body.statements = mappings + body.statements
if not self._omit_range_check:
stop = PsExpression.make(ispace.index_list.shape[0])
condition = PsLt(sparse_ctr_expr.clone(), stop)
ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
if cond:
ast = PsBlock([sparse_idx_decl, PsConditional(cond, body)])
else:
body.statements = [sparse_idx_decl] + body.statements
ast = body
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment