diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index cff6f935f2f89e09e2d0c8323e3d245c2b091c23..122011eb0cefcb07b7cf6f550519fc6777deb6c3 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -49,7 +49,7 @@ GRID_DIM = [
 class ThreadToIndexMapping(ABC):
 
     @abstractmethod
-    def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
+    def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
         """Map the current thread index onto a point in the given iteration space.
 
         Implementations of this method must return a declaration for each dimension counter
@@ -61,7 +61,18 @@ class Linear3DMapping(ThreadToIndexMapping):
     """3D globally linearized mapping, where each thread is assigned a work item according to its location
     in the global launch grid."""
 
-    def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
+    def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
+        match ispace:
+            case FullIterationSpace():
+                return self._dense_mapping(ispace)
+            case SparseIterationSpace():
+                return self._sparse_mapping(ispace)
+            case _:
+                assert False, "unexpected iteration space"
+
+    def _dense_mapping(
+        self, ispace: FullIterationSpace
+    ) -> dict[PsSymbol, PsExpression]:
         if ispace.rank > 3:
             raise MaterializationError(
                 f"Cannot handle {ispace.rank}-dimensional iteration space "
@@ -79,6 +90,18 @@ class Linear3DMapping(ThreadToIndexMapping):
 
         return idx_map
 
+    def _sparse_mapping(
+        self, ispace: SparseIterationSpace
+    ) -> dict[PsSymbol, PsExpression]:
+        sparse_ctr = PsExpression.make(ispace.sparse_counter)
+        thread_idx = self._linear_thread_idx(0)
+        idx_map: dict[PsSymbol, PsExpression] = {
+            ispace.sparse_counter: PsCast(
+                deconstify(sparse_ctr.get_dtype()), thread_idx
+            )
+        }
+        return idx_map
+
     def _linear_thread_idx(self, coord: int):
         block_size = BLOCK_DIM[coord]
         block_idx = BLOCK_IDX[coord]
@@ -97,7 +120,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
         THREAD_IDX[0],
     ]
 
-    def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
+    def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]:
+        match ispace:
+            case FullIterationSpace():
+                return self._dense_mapping(ispace)
+            case SparseIterationSpace():
+                return self._sparse_mapping(ispace)
+            case _:
+                assert False, "unexpected iteration space"
+
+    def _dense_mapping(
+        self, ispace: FullIterationSpace
+    ) -> dict[PsSymbol, PsExpression]:
         if ispace.rank > 4:
             raise MaterializationError(
                 f"Cannot handle {ispace.rank}-dimensional iteration space "
@@ -114,6 +148,18 @@ class Blockwise4DMapping(ThreadToIndexMapping):
 
         return idx_map
 
+    def _sparse_mapping(
+        self, ispace: SparseIterationSpace
+    ) -> dict[PsSymbol, PsExpression]:
+        sparse_ctr = PsExpression.make(ispace.sparse_counter)
+        thread_idx = self._indices_in_loop_order[-1]
+        idx_map: dict[PsSymbol, PsExpression] = {
+            ispace.sparse_counter: PsCast(
+                deconstify(sparse_ctr.get_dtype()), thread_idx
+            )
+        }
+        return idx_map
+
 
 class CudaPlatform(GenericGpu):
     """Platform for CUDA-based GPUs."""
@@ -127,7 +173,9 @@ class CudaPlatform(GenericGpu):
         super().__init__(ctx)
 
         self._omit_range_check = omit_range_check
-        self._thread_mapping = thread_mapping
+        self._thread_mapping = (
+            thread_mapping if thread_mapping is not None else Linear3DMapping()
+        )
 
         self._typify = Typifier(ctx)
 
@@ -212,26 +260,7 @@ class CudaPlatform(GenericGpu):
     ) -> PsBlock:
         dimensions = ispace.dimensions_in_loop_order()
 
-        # TODO move to codegen
-        # if not self._manual_launch_grid:
-        #     try:
-        #         threads_range = self.threads_from_ispace(ispace)
-        #     except MaterializationError as e:
-        #         warn(
-        #             str(e.args[0])
-        #             + "\nIf this is intended, set `manual_launch_grid=True` in the code generator configuration.",
-        #             UserWarning,
-        #         )
-        #         threads_range = None
-        # else:
-        #     threads_range = None
-
-        idx_mapper = (
-            self._thread_mapping
-            if self._thread_mapping is not None
-            else Linear3DMapping()
-        )
-        ctr_mapping = idx_mapper(ispace)
+        ctr_mapping = self._thread_mapping(ispace)
 
         indexing_decls = []
         conds = []
@@ -264,10 +293,11 @@ class CudaPlatform(GenericGpu):
         factory = AstFactory(self._ctx)
         ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
 
-        sparse_ctr = PsExpression.make(ispace.sparse_counter)
-        thread_idx = BLOCK_IDX[0] * BLOCK_DIM[0] + THREAD_IDX[0]
+        sparse_ctr_expr = PsExpression.make(ispace.sparse_counter)
+        ctr_mapping = self._thread_mapping(ispace)
+
         sparse_idx_decl = self._typify(
-            PsDeclaration(sparse_ctr, PsCast(sparse_ctr.get_dtype(), thread_idx))
+            PsDeclaration(sparse_ctr_expr, ctr_mapping[ispace.sparse_counter])
         )
 
         mappings = [
@@ -276,7 +306,7 @@ class CudaPlatform(GenericGpu):
                 PsLookup(
                     PsBufferAcc(
                         ispace.index_list.base_pointer,
-                        (sparse_ctr, factory.parse_index(0)),
+                        (sparse_ctr_expr.clone(), factory.parse_index(0)),
                     ),
                     coord.name,
                 ),
@@ -287,7 +317,7 @@ class CudaPlatform(GenericGpu):
 
         if not self._omit_range_check:
            stop = PsExpression.make(ispace.index_list.shape[0])
-            condition = PsLt(sparse_ctr, stop)
+            condition = PsLt(sparse_ctr_expr.clone(), stop)
             ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)])
         else:
             body.statements = [sparse_idx_decl] + body.statements
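
The sketch below is not part of the patch; it only illustrates the dispatch pattern the diff introduces: a `ThreadToIndexMapping` now accepts any `IterationSpace` and selects a dense or sparse mapping by matching on the concrete type, with `Linear3DMapping` as the fallback when no mapping is supplied. All classes here are simplified stand-ins (strings instead of `PsSymbol`/`PsExpression`), not the real pystencils types.

```python
# Standalone sketch of the type-dispatch pattern from the diff above.
# IterationSpace, FullIterationSpace, SparseIterationSpace and the mapping
# classes are simplified stand-ins, not the real pystencils implementations.
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass


class IterationSpace(ABC):
    """Common base for dense and sparse iteration spaces (stand-in)."""


@dataclass
class FullIterationSpace(IterationSpace):
    rank: int


@dataclass
class SparseIterationSpace(IterationSpace):
    sparse_counter: str


class ThreadToIndexMapping(ABC):
    @abstractmethod
    def __call__(self, ispace: IterationSpace) -> dict[str, str]:
        """Map the current thread index onto a point in the given iteration space."""


class Linear3DMapping(ThreadToIndexMapping):
    def __call__(self, ispace: IterationSpace) -> dict[str, str]:
        # Dispatch on the concrete iteration-space type, as in the diff.
        match ispace:
            case FullIterationSpace():
                return self._dense_mapping(ispace)
            case SparseIterationSpace():
                return self._sparse_mapping(ispace)
            case _:
                assert False, "unexpected iteration space"

    def _dense_mapping(self, ispace: FullIterationSpace) -> dict[str, str]:
        # One linearized GPU thread coordinate per spatial dimension (at most 3).
        assert ispace.rank <= 3
        return {
            f"ctr_{d}": f"blockIdx.{'xyz'[d]} * blockDim.{'xyz'[d]} + threadIdx.{'xyz'[d]}"
            for d in range(ispace.rank)
        }

    def _sparse_mapping(self, ispace: SparseIterationSpace) -> dict[str, str]:
        # The sparse counter is driven by the linearized x thread index.
        return {ispace.sparse_counter: "blockIdx.x * blockDim.x + threadIdx.x"}


def resolve_mapping(thread_mapping: ThreadToIndexMapping | None) -> ThreadToIndexMapping:
    # Mirrors the new constructor behaviour: fall back to Linear3DMapping
    # when the caller does not supply a mapping.
    return thread_mapping if thread_mapping is not None else Linear3DMapping()


if __name__ == "__main__":
    mapping = resolve_mapping(None)
    print(mapping(FullIterationSpace(rank=2)))
    print(mapping(SparseIterationSpace(sparse_counter="idx")))
```

Because both `Linear3DMapping` and `Blockwise4DMapping` now handle sparse spaces themselves, the sparse code path in `CudaPlatform` can reuse whatever mapping is configured instead of hard-coding `BLOCK_IDX[0] * BLOCK_DIM[0] + THREAD_IDX[0]`.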