diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index fb613347a94f1ff479f6cbf85e9e0c1fb540d1db..eff88df7e2dcb317403ff417bd996f5e8acb09c0 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -315,13 +315,47 @@ class CudaPlatform(GenericGpu): # Internals + # TODO: SYCL platform has very similar code for fetching conditionals -> move to GenericGPU? + + def _get_condition_for_translation( + self, ispace: IterationSpace): + + if not self._omit_range_check: + return None + + match ispace: + case FullIterationSpace(): + + dimensions = ispace.dimensions_in_loop_order() + + conds = [] + for dim in dimensions: + ctr_expr = PsExpression.make(dim.counter) + conds.append(PsLt(ctr_expr, dim.stop)) + + if conds: + condition: PsExpression = conds[0] + for cond in conds[1:]: + condition = PsAnd(condition, cond) + return condition + else: + return None + + case SparseIterationSpace(): + sparse_ctr_expr = PsExpression.make(ispace.sparse_counter) + stop = PsExpression.make(ispace.index_list.shape[0]) + + return PsLt(sparse_ctr_expr.clone(), stop) + case _: + assert False, "Unknown iteration space" + def _prepend_dense_translation( self, body: PsBlock, ispace: FullIterationSpace ) -> PsBlock: ctr_mapping = self._thread_mapping(ispace) indexing_decls = [] - conds = [] + cond = self._get_condition_for_translation(ispace) dimensions = ispace.dimensions_in_loop_order() @@ -335,14 +369,9 @@ class CudaPlatform(GenericGpu): indexing_decls.append( self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter])) ) - if not self._omit_range_check: - conds.append(PsLt(ctr_expr, dim.stop)) - - if conds: - condition: PsExpression = conds[0] - for cond in conds[1:]: - condition = PsAnd(condition, cond) - ast = PsBlock(indexing_decls + [PsConditional(condition, body)]) + + if cond: + ast = PsBlock(indexing_decls + [PsConditional(cond, body)]) else: body.statements = indexing_decls + body.statements ast = body @@ -355,6 +384,8 @@ class CudaPlatform(GenericGpu): factory = AstFactory(self._ctx) ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) + cond = self._get_condition_for_translation(ispace) + sparse_ctr_expr = PsExpression.make(ispace.sparse_counter) ctr_mapping = self._thread_mapping(ispace) @@ -377,10 +408,8 @@ class CudaPlatform(GenericGpu): ] body.statements = mappings + body.statements - if not self._omit_range_check: - stop = PsExpression.make(ispace.index_list.shape[0]) - condition = PsLt(sparse_ctr_expr.clone(), stop) - ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)]) + if cond: + ast = PsBlock([sparse_idx_decl, PsConditional(cond, body)]) else: body.statements = [sparse_idx_decl] + body.statements ast = body