diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 2918588108b2da2305896b39a3dfa5fda4d78593..e67b70db649ec9f96df85e5a93539c27fc85e784 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -172,7 +172,7 @@ class Blockwise4DMapping(ThreadMapping): class CudaPlatform(GenericGpu): """Platform for CUDA-based GPUs. - + Args: ctx: The kernel creation context omit_range_check: If `True`, generated index translation code will not check if the point identified @@ -209,6 +209,33 @@ class CudaPlatform(GenericGpu): else: raise MaterializationError(f"Unknown type of iteration space: {ispace}") + def _get_condition_for_translation(self, ispace: IterationSpace): + if self._omit_range_check: + return None + + if isinstance(ispace, FullIterationSpace): + conds = [] + + dimensions = ispace.dimensions_in_loop_order() + + for dim in dimensions: + ctr_expr = PsExpression.make(dim.counter) + conds.append(PsLt(ctr_expr, dim.stop)) + + condition: PsExpression = conds[0] + for cond in conds[1:]: + condition = PsAnd(condition, cond) + + return condition + elif isinstance(ispace, SparseIterationSpace): + sparse_ctr_expr = PsExpression.make(ispace.sparse_counter) + stop = PsExpression.make(ispace.index_list.shape[0]) + + return PsLt(sparse_ctr_expr.clone(), stop) + else: + raise MaterializationError(f"Unknown type of iteration space: {ispace}") + + def select_function(self, call: PsCall) -> PsExpression | tuple[tuple[PsAstNode, ...], PsAstNode]: call_func = call.function assert isinstance(call_func, PsReductionFunction | PsMathFunction) @@ -341,47 +368,12 @@ class CudaPlatform(GenericGpu): # Internals - # TODO: SYCL platform has very similar code for fetching conditionals -> move to GenericGPU? - - def _get_condition_for_translation( - self, ispace: IterationSpace): - - if self._omit_range_check: - return None - - match ispace: - case FullIterationSpace(): - - dimensions = ispace.dimensions_in_loop_order() - - conds = [] - for dim in dimensions: - ctr_expr = PsExpression.make(dim.counter) - conds.append(PsLt(ctr_expr, dim.stop)) - - if conds: - condition: PsExpression = conds[0] - for cond in conds[1:]: - condition = PsAnd(condition, cond) - return condition - else: - return None - - case SparseIterationSpace(): - sparse_ctr_expr = PsExpression.make(ispace.sparse_counter) - stop = PsExpression.make(ispace.index_list.shape[0]) - - return PsLt(sparse_ctr_expr.clone(), stop) - case _: - assert False, "Unknown iteration space" - def _prepend_dense_translation( self, body: PsBlock, ispace: FullIterationSpace ) -> PsBlock: ctr_mapping = self._thread_mapping(ispace) indexing_decls = [] - cond = self._get_condition_for_translation(ispace) dimensions = ispace.dimensions_in_loop_order() @@ -396,6 +388,7 @@ class CudaPlatform(GenericGpu): self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter])) ) + cond = self._get_condition_for_translation(ispace) if cond: ast = PsBlock(indexing_decls + [PsConditional(cond, body)]) else: @@ -410,8 +403,6 @@ class CudaPlatform(GenericGpu): factory = AstFactory(self._ctx) ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) - cond = self._get_condition_for_translation(ispace) - sparse_ctr_expr = PsExpression.make(ispace.sparse_counter) ctr_mapping = self._thread_mapping(ispace) @@ -434,6 +425,7 @@ class CudaPlatform(GenericGpu): ] body.statements = mappings + body.statements + cond = self._get_condition_for_translation(ispace) if cond: ast = PsBlock([sparse_idx_decl, PsConditional(cond, body)]) else: