Skip to content
Snippets Groups Projects
Commit a972d759 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Fix getter for thread exec condition for dense/sparse iteration spaces in cuda.py

parent 02be4d5e
No related branches found
No related tags found
1 merge request!438Reduction Support
Pipeline #74694 failed
......@@ -172,7 +172,7 @@ class Blockwise4DMapping(ThreadMapping):
class CudaPlatform(GenericGpu):
"""Platform for CUDA-based GPUs.
Args:
ctx: The kernel creation context
omit_range_check: If `True`, generated index translation code will not check if the point identified
......@@ -209,6 +209,33 @@ class CudaPlatform(GenericGpu):
else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def _get_condition_for_translation(self, ispace: IterationSpace):
if self._omit_range_check:
return None
if isinstance(ispace, FullIterationSpace):
conds = []
dimensions = ispace.dimensions_in_loop_order()
for dim in dimensions:
ctr_expr = PsExpression.make(dim.counter)
conds.append(PsLt(ctr_expr, dim.stop))
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
return condition
elif isinstance(ispace, SparseIterationSpace):
sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
stop = PsExpression.make(ispace.index_list.shape[0])
return PsLt(sparse_ctr_expr.clone(), stop)
else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]:
call_func = call.function
assert isinstance(call_func, PsReductionFunction | PsMathFunction)
......@@ -341,47 +368,12 @@ class CudaPlatform(GenericGpu):
# Internals
# TODO: SYCL platform has very similar code for fetching conditionals -> move to GenericGPU?
def _get_condition_for_translation(
self, ispace: IterationSpace):
if self._omit_range_check:
return None
match ispace:
case FullIterationSpace():
dimensions = ispace.dimensions_in_loop_order()
conds = []
for dim in dimensions:
ctr_expr = PsExpression.make(dim.counter)
conds.append(PsLt(ctr_expr, dim.stop))
if conds:
condition: PsExpression = conds[0]
for cond in conds[1:]:
condition = PsAnd(condition, cond)
return condition
else:
return None
case SparseIterationSpace():
sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
stop = PsExpression.make(ispace.index_list.shape[0])
return PsLt(sparse_ctr_expr.clone(), stop)
case _:
assert False, "Unknown iteration space"
def _prepend_dense_translation(
self, body: PsBlock, ispace: FullIterationSpace
) -> PsBlock:
ctr_mapping = self._thread_mapping(ispace)
indexing_decls = []
cond = self._get_condition_for_translation(ispace)
dimensions = ispace.dimensions_in_loop_order()
......@@ -396,6 +388,7 @@ class CudaPlatform(GenericGpu):
self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter]))
)
cond = self._get_condition_for_translation(ispace)
if cond:
ast = PsBlock(indexing_decls + [PsConditional(cond, body)])
else:
......@@ -410,8 +403,6 @@ class CudaPlatform(GenericGpu):
factory = AstFactory(self._ctx)
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
cond = self._get_condition_for_translation(ispace)
sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
ctr_mapping = self._thread_mapping(ispace)
......@@ -434,6 +425,7 @@ class CudaPlatform(GenericGpu):
]
body.statements = mappings + body.statements
cond = self._get_condition_for_translation(ispace)
if cond:
ast = PsBlock([sparse_idx_decl, PsConditional(cond, body)])
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment