diff --git a/src/pystencils/backend/platforms/__init__.py b/src/pystencils/backend/platforms/__init__.py index 589841db87efb598ffeed20d4d11db7ffcd452cc..3b602964bf2442823c9965f3398517635c8cc217 100644 --- a/src/pystencils/backend/platforms/__init__.py +++ b/src/pystencils/backend/platforms/__init__.py @@ -2,6 +2,7 @@ from .platform import Platform from .generic_cpu import GenericCpu, GenericVectorCpu from .generic_gpu import GenericGpu from .cuda import CudaPlatform +from .hip import HipPlatform from .x86 import X86VectorCpu, X86VectorArch from .sycl import SyclPlatform @@ -13,5 +14,6 @@ __all__ = [ "X86VectorArch", "GenericGpu", "CudaPlatform", + "HipPlatform", "SyclPlatform", ] diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index d630594ba8a633e04e2b058e74f77d604f6d95a9..98ff3e3d332a46074931514ba3af1603dc6318b2 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -1,335 +1,11 @@ from __future__ import annotations -from abc import ABC, abstractmethod -from ...types import constify, deconstify -from ..exceptions import MaterializationError from .generic_gpu import GenericGpu -from ..memory import PsSymbol -from ..kernelcreation import ( - Typifier, - IterationSpace, - FullIterationSpace, - SparseIterationSpace, - AstFactory, -) - -from ..kernelcreation.context import KernelCreationContext -from ..ast.structural import PsBlock, PsConditional, PsDeclaration -from ..ast.expressions import ( - PsExpression, - PsLiteralExpr, - PsCast, - PsCall, - PsLookup, - PsBufferAcc, -) -from ..ast.expressions import PsLt, PsAnd -from ...types import PsSignedIntegerType, PsIeeeFloatType -from ..literals import PsLiteral -from ..functions import PsMathFunction, MathFunctions, CFunction - - -int32 = PsSignedIntegerType(width=32, const=False) - -BLOCK_IDX = [ - PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z") -] -THREAD_IDX = [ - PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z") -] -BLOCK_DIM = [ - PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z") -] -GRID_DIM = [ - PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z") -] - - -class ThreadMapping(ABC): - - @abstractmethod - def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]: - """Map the current thread index onto a point in the given iteration space. - - Implementations of this method must return a declaration for each dimension counter - of the given iteration space. - """ - - -class Linear3DMapping(ThreadMapping): - """3D globally linearized mapping, where each thread is assigned a work item according to - its location in the global launch grid.""" - - def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]: - match ispace: - case FullIterationSpace(): - return self._dense_mapping(ispace) - case SparseIterationSpace(): - return self._sparse_mapping(ispace) - case _: - assert False, "unexpected iteration space" - - def _dense_mapping( - self, ispace: FullIterationSpace - ) -> dict[PsSymbol, PsExpression]: - if ispace.rank > 3: - raise MaterializationError( - f"Cannot handle {ispace.rank}-dimensional iteration space " - "using the Linear3D GPU thread index mapping." 
- ) - - dimensions = ispace.dimensions_in_loop_order() - idx_map: dict[PsSymbol, PsExpression] = dict() - - for coord, dim in enumerate(dimensions[::-1]): - tid = self._linear_thread_idx(coord) - idx_map[dim.counter] = dim.start + dim.step * PsCast( - deconstify(dim.counter.get_dtype()), tid - ) - - return idx_map - - def _sparse_mapping( - self, ispace: SparseIterationSpace - ) -> dict[PsSymbol, PsExpression]: - sparse_ctr = PsExpression.make(ispace.sparse_counter) - thread_idx = self._linear_thread_idx(0) - idx_map: dict[PsSymbol, PsExpression] = { - ispace.sparse_counter: PsCast( - deconstify(sparse_ctr.get_dtype()), thread_idx - ) - } - return idx_map - - def _linear_thread_idx(self, coord: int): - block_size = BLOCK_DIM[coord] - block_idx = BLOCK_IDX[coord] - thread_idx = THREAD_IDX[coord] - return block_idx * block_size + thread_idx - - -class Blockwise4DMapping(ThreadMapping): - """Blockwise index mapping for up to 4D iteration spaces, where the outer three dimensions - are mapped to block indices.""" - - _indices_fastest_first = [ # slowest to fastest - THREAD_IDX[0], - BLOCK_IDX[0], - BLOCK_IDX[1], - BLOCK_IDX[2] - ] - - def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]: - match ispace: - case FullIterationSpace(): - return self._dense_mapping(ispace) - case SparseIterationSpace(): - return self._sparse_mapping(ispace) - case _: - assert False, "unexpected iteration space" - - def _dense_mapping( - self, ispace: FullIterationSpace - ) -> dict[PsSymbol, PsExpression]: - if ispace.rank > 4: - raise MaterializationError( - f"Cannot handle {ispace.rank}-dimensional iteration space " - "using the Blockwise4D GPU thread index mapping." - ) - - dimensions = ispace.dimensions_in_loop_order() - idx_map: dict[PsSymbol, PsExpression] = dict() - - for dim, tid in zip(dimensions[::-1], self._indices_fastest_first): - idx_map[dim.counter] = dim.start + dim.step * PsCast( - deconstify(dim.counter.get_dtype()), tid - ) - - return idx_map - - def _sparse_mapping( - self, ispace: SparseIterationSpace - ) -> dict[PsSymbol, PsExpression]: - sparse_ctr = PsExpression.make(ispace.sparse_counter) - thread_idx = self._indices_fastest_first[0] - idx_map: dict[PsSymbol, PsExpression] = { - ispace.sparse_counter: PsCast( - deconstify(sparse_ctr.get_dtype()), thread_idx - ) - } - return idx_map - class CudaPlatform(GenericGpu): - """Platform for CUDA-based GPUs. 
- - Args: - ctx: The kernel creation context - omit_range_check: If `True`, generated index translation code will not check if the point identified - by block and thread indices is actually contained in the iteration space - thread_mapping: Callback object which defines the mapping of thread indices onto iteration space points - """ - - def __init__( - self, - ctx: KernelCreationContext, - omit_range_check: bool = False, - thread_mapping: ThreadMapping | None = None, - ) -> None: - super().__init__(ctx) - - self._omit_range_check = omit_range_check - self._thread_mapping = ( - thread_mapping if thread_mapping is not None else Linear3DMapping() - ) - - self._typify = Typifier(ctx) + """Platform for the CUDA GPU target.""" @property def required_headers(self) -> set[str]: - return {'"pystencils_runtime/hip.h"'} # TODO: move to HipPlatform once it is introduced - - def materialize_iteration_space( - self, body: PsBlock, ispace: IterationSpace - ) -> PsBlock: - if isinstance(ispace, FullIterationSpace): - return self._prepend_dense_translation(body, ispace) - elif isinstance(ispace, SparseIterationSpace): - return self._prepend_sparse_translation(body, ispace) - else: - raise MaterializationError(f"Unknown type of iteration space: {ispace}") - - def select_function(self, call: PsCall) -> PsExpression: - assert isinstance(call.function, PsMathFunction) - - func = call.function.func - dtype = call.get_dtype() - arg_types = (dtype,) * func.num_args - - if isinstance(dtype, PsIeeeFloatType): - match func: - case ( - MathFunctions.Exp - | MathFunctions.Log - | MathFunctions.Sin - | MathFunctions.Cos - | MathFunctions.Sqrt - | MathFunctions.Ceil - | MathFunctions.Floor - ) if dtype.width in (16, 32, 64): - prefix = "h" if dtype.width == 16 else "" - suffix = "f" if dtype.width == 32 else "" - name = f"{prefix}{func.function_name}{suffix}" - cfunc = CFunction(name, arg_types, dtype) - - case ( - MathFunctions.Pow - | MathFunctions.Tan - | MathFunctions.Sinh - | MathFunctions.Cosh - | MathFunctions.ASin - | MathFunctions.ACos - | MathFunctions.ATan - | MathFunctions.ATan2 - ) if dtype.width in (32, 64): - # These are unavailable for fp16 - suffix = "f" if dtype.width == 32 else "" - name = f"{func.function_name}{suffix}" - cfunc = CFunction(name, arg_types, dtype) - - case ( - MathFunctions.Min | MathFunctions.Max | MathFunctions.Abs - ) if dtype.width in (32, 64): - suffix = "f" if dtype.width == 32 else "" - name = f"f{func.function_name}{suffix}" - cfunc = CFunction(name, arg_types, dtype) - - case MathFunctions.Abs if dtype.width == 16: - cfunc = CFunction(" __habs", arg_types, dtype) - - case _: - raise MaterializationError( - f"Cannot materialize call to function {func}" - ) - - call.function = cfunc - return call - - raise MaterializationError( - f"No implementation available for function {func} on data type {dtype}" - ) - - # Internals - - def _prepend_dense_translation( - self, body: PsBlock, ispace: FullIterationSpace - ) -> PsBlock: - ctr_mapping = self._thread_mapping(ispace) - - indexing_decls = [] - conds = [] - - dimensions = ispace.dimensions_in_loop_order() - - for dim in dimensions: - # counter declarations must be ordered slowest-to-fastest - # such that inner dimensions can depend on outer ones - - dim.counter.dtype = constify(dim.counter.get_dtype()) - - ctr_expr = PsExpression.make(dim.counter) - indexing_decls.append( - self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter])) - ) - if not self._omit_range_check: - conds.append(PsLt(ctr_expr, dim.stop)) - - if conds: -
condition: PsExpression = conds[0] - for cond in conds[1:]: - condition = PsAnd(condition, cond) - ast = PsBlock(indexing_decls + [PsConditional(condition, body)]) - else: - body.statements = indexing_decls + body.statements - ast = body - - return ast - - def _prepend_sparse_translation( - self, body: PsBlock, ispace: SparseIterationSpace - ) -> PsBlock: - factory = AstFactory(self._ctx) - ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) - - sparse_ctr_expr = PsExpression.make(ispace.sparse_counter) - ctr_mapping = self._thread_mapping(ispace) - - sparse_idx_decl = self._typify( - PsDeclaration(sparse_ctr_expr, ctr_mapping[ispace.sparse_counter]) - ) - - mappings = [ - PsDeclaration( - PsExpression.make(ctr), - PsLookup( - PsBufferAcc( - ispace.index_list.base_pointer, - (sparse_ctr_expr.clone(), factory.parse_index(0)), - ), - coord.name, - ), - ) - for ctr, coord in zip(ispace.spatial_indices, ispace.coordinate_members) - ] - body.statements = mappings + body.statements - - if not self._omit_range_check: - stop = PsExpression.make(ispace.index_list.shape[0]) - condition = PsLt(sparse_ctr_expr.clone(), stop) - ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)]) - else: - body.statements = [sparse_idx_decl] + body.statements - ast = body - - return ast + return set() diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index b5b35c8b03447f1d5c35ed1289b89542bb1127ca..fac37ffa5955338c3eac0126eccc3014e3f0c4ba 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -1,7 +1,331 @@ from __future__ import annotations +from abc import ABC, abstractmethod +from ...types import constify, deconstify +from ..exceptions import MaterializationError from .platform import Platform +from ..memory import PsSymbol +from ..kernelcreation import ( + Typifier, + IterationSpace, + FullIterationSpace, + SparseIterationSpace, + AstFactory, +) + +from ..kernelcreation.context import KernelCreationContext +from ..ast.structural import PsBlock, PsConditional, PsDeclaration +from ..ast.expressions import ( + PsExpression, + PsLiteralExpr, + PsCast, + PsCall, + PsLookup, + PsBufferAcc, +) +from ..ast.expressions import PsLt, PsAnd +from ...types import PsSignedIntegerType, PsIeeeFloatType +from ..literals import PsLiteral +from ..functions import PsMathFunction, MathFunctions, CFunction + + +int32 = PsSignedIntegerType(width=32, const=False) + +BLOCK_IDX = [ + PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z") +] +THREAD_IDX = [ + PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z") +] +BLOCK_DIM = [ + PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z") +] +GRID_DIM = [ + PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z") +] + + +class ThreadMapping(ABC): + + @abstractmethod + def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]: + """Map the current thread index onto a point in the given iteration space. + + Implementations of this method must return a declaration for each dimension counter + of the given iteration space. 
+ """ + + +class Linear3DMapping(ThreadMapping): + """3D globally linearized mapping, where each thread is assigned a work item according to + its location in the global launch grid.""" + + def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]: + match ispace: + case FullIterationSpace(): + return self._dense_mapping(ispace) + case SparseIterationSpace(): + return self._sparse_mapping(ispace) + case _: + assert False, "unexpected iteration space" + + def _dense_mapping( + self, ispace: FullIterationSpace + ) -> dict[PsSymbol, PsExpression]: + if ispace.rank > 3: + raise MaterializationError( + f"Cannot handle {ispace.rank}-dimensional iteration space " + "using the Linear3D GPU thread index mapping." + ) + + dimensions = ispace.dimensions_in_loop_order() + idx_map: dict[PsSymbol, PsExpression] = dict() + + for coord, dim in enumerate(dimensions[::-1]): + tid = self._linear_thread_idx(coord) + idx_map[dim.counter] = dim.start + dim.step * PsCast( + deconstify(dim.counter.get_dtype()), tid + ) + + return idx_map + + def _sparse_mapping( + self, ispace: SparseIterationSpace + ) -> dict[PsSymbol, PsExpression]: + sparse_ctr = PsExpression.make(ispace.sparse_counter) + thread_idx = self._linear_thread_idx(0) + idx_map: dict[PsSymbol, PsExpression] = { + ispace.sparse_counter: PsCast( + deconstify(sparse_ctr.get_dtype()), thread_idx + ) + } + return idx_map + + def _linear_thread_idx(self, coord: int): + block_size = BLOCK_DIM[coord] + block_idx = BLOCK_IDX[coord] + thread_idx = THREAD_IDX[coord] + return block_idx * block_size + thread_idx + + +class Blockwise4DMapping(ThreadMapping): + """Blockwise index mapping for up to 4D iteration spaces, where the outer three dimensions + are mapped to block indices.""" + + _indices_fastest_first = [ # slowest to fastest + THREAD_IDX[0], + BLOCK_IDX[0], + BLOCK_IDX[1], + BLOCK_IDX[2] + ] + + def __call__(self, ispace: IterationSpace) -> dict[PsSymbol, PsExpression]: + match ispace: + case FullIterationSpace(): + return self._dense_mapping(ispace) + case SparseIterationSpace(): + return self._sparse_mapping(ispace) + case _: + assert False, "unexpected iteration space" + + def _dense_mapping( + self, ispace: FullIterationSpace + ) -> dict[PsSymbol, PsExpression]: + if ispace.rank > 4: + raise MaterializationError( + f"Cannot handle {ispace.rank}-dimensional iteration space " + "using the Blockwise4D GPU thread index mapping." + ) + + dimensions = ispace.dimensions_in_loop_order() + idx_map: dict[PsSymbol, PsExpression] = dict() + + for dim, tid in zip(dimensions[::-1], self._indices_fastest_first): + idx_map[dim.counter] = dim.start + dim.step * PsCast( + deconstify(dim.counter.get_dtype()), tid + ) + + return idx_map + + def _sparse_mapping( + self, ispace: SparseIterationSpace + ) -> dict[PsSymbol, PsExpression]: + sparse_ctr = PsExpression.make(ispace.sparse_counter) + thread_idx = self._indices_fastest_first[0] + idx_map: dict[PsSymbol, PsExpression] = { + ispace.sparse_counter: PsCast( + deconstify(sparse_ctr.get_dtype()), thread_idx + ) + } + return idx_map + class GenericGpu(Platform): - """Base class for GPU platforms.""" + """Common base platform for CUDA- and HIP-type GPU targets. 
+ + Args: + ctx: The kernel creation context + omit_range_check: If `True`, generated index translation code will not check if the point identified + by block and thread indices is actually contained in the iteration space + thread_mapping: Callback object which defines the mapping of thread indices onto iteration space points + """ + + def __init__( + self, + ctx: KernelCreationContext, + omit_range_check: bool = False, + thread_mapping: ThreadMapping | None = None, + ) -> None: + super().__init__(ctx) + + self._omit_range_check = omit_range_check + self._thread_mapping = ( + thread_mapping if thread_mapping is not None else Linear3DMapping() + ) + + self._typify = Typifier(ctx) + + def materialize_iteration_space( + self, body: PsBlock, ispace: IterationSpace + ) -> PsBlock: + if isinstance(ispace, FullIterationSpace): + return self._prepend_dense_translation(body, ispace) + elif isinstance(ispace, SparseIterationSpace): + return self._prepend_sparse_translation(body, ispace) + else: + raise MaterializationError(f"Unknown type of iteration space: {ispace}") + + def select_function(self, call: PsCall) -> PsExpression: + assert isinstance(call.function, PsMathFunction) + + func = call.function.func + dtype = call.get_dtype() + arg_types = (dtype,) * func.num_args + + if isinstance(dtype, PsIeeeFloatType): + match func: + case ( + MathFunctions.Exp + | MathFunctions.Log + | MathFunctions.Sin + | MathFunctions.Cos + | MathFunctions.Sqrt + | MathFunctions.Ceil + | MathFunctions.Floor + ) if dtype.width in (16, 32, 64): + prefix = "h" if dtype.width == 16 else "" + suffix = "f" if dtype.width == 32 else "" + name = f"{prefix}{func.function_name}{suffix}" + cfunc = CFunction(name, arg_types, dtype) + + case ( + MathFunctions.Pow + | MathFunctions.Tan + | MathFunctions.Sinh + | MathFunctions.Cosh + | MathFunctions.ASin + | MathFunctions.ACos + | MathFunctions.ATan + | MathFunctions.ATan2 + ) if dtype.width in (32, 64): + # These are unavailable for fp16 + suffix = "f" if dtype.width == 32 else "" + name = f"{func.function_name}{suffix}" + cfunc = CFunction(name, arg_types, dtype) + + case ( + MathFunctions.Min | MathFunctions.Max | MathFunctions.Abs + ) if dtype.width in (32, 64): + suffix = "f" if dtype.width == 32 else "" + name = f"f{func.function_name}{suffix}" + cfunc = CFunction(name, arg_types, dtype) + + case MathFunctions.Abs if dtype.width == 16: + cfunc = CFunction(" __habs", arg_types, dtype) + + case _: + raise MaterializationError( + f"Cannot materialize call to function {func}" + ) + + call.function = cfunc + return call + + raise MaterializationError( + f"No implementation available for function {func} on data type {dtype}" + ) + + # Internals + + def _prepend_dense_translation( + self, body: PsBlock, ispace: FullIterationSpace + ) -> PsBlock: + ctr_mapping = self._thread_mapping(ispace) + + indexing_decls = [] + conds = [] + + dimensions = ispace.dimensions_in_loop_order() + + for dim in dimensions: + # counter declarations must be ordered slowest-to-fastest + # such that inner dimensions can depend on outer ones + + dim.counter.dtype = constify(dim.counter.get_dtype()) + + ctr_expr = PsExpression.make(dim.counter) + indexing_decls.append( + self._typify(PsDeclaration(ctr_expr, ctr_mapping[dim.counter])) + ) + if not self._omit_range_check: + conds.append(PsLt(ctr_expr, dim.stop)) + + if conds: + condition: PsExpression = conds[0] + for cond in conds[1:]: + condition = PsAnd(condition, cond) + ast = PsBlock(indexing_decls + [PsConditional(condition, body)]) + else: + 
body.statements = indexing_decls + body.statements + ast = body + + return ast + + def _prepend_sparse_translation( + self, body: PsBlock, ispace: SparseIterationSpace + ) -> PsBlock: + factory = AstFactory(self._ctx) + ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) + + sparse_ctr_expr = PsExpression.make(ispace.sparse_counter) + ctr_mapping = self._thread_mapping(ispace) + + sparse_idx_decl = self._typify( + PsDeclaration(sparse_ctr_expr, ctr_mapping[ispace.sparse_counter]) + ) + + mappings = [ + PsDeclaration( + PsExpression.make(ctr), + PsLookup( + PsBufferAcc( + ispace.index_list.base_pointer, + (sparse_ctr_expr.clone(), factory.parse_index(0)), + ), + coord.name, + ), + ) + for ctr, coord in zip(ispace.spatial_indices, ispace.coordinate_members) + ] + body.statements = mappings + body.statements + + if not self._omit_range_check: + stop = PsExpression.make(ispace.index_list.shape[0]) + condition = PsLt(sparse_ctr_expr.clone(), stop) + ast = PsBlock([sparse_idx_decl, PsConditional(condition, body)]) + else: + body.statements = [sparse_idx_decl] + body.statements + ast = body + + return ast diff --git a/src/pystencils/backend/platforms/hip.py b/src/pystencils/backend/platforms/hip.py new file mode 100644 index 0000000000000000000000000000000000000000..c758995a0d9f8fbbb2e9e424bf2cfa6ab7eca086 --- /dev/null +++ b/src/pystencils/backend/platforms/hip.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .generic_gpu import GenericGpu + + +class HipPlatform(GenericGpu): + """Platform for the HIP GPU target.""" + + @property + def required_headers(self) -> set[str]: + return {'"pystencils_runtime/hip.h"'} diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index f53f1b9b8abb3f13c3a0a8aa07a73c06e81c72db..9e3c8b163b8b6425ccc244107c3fd1eb83fc4df3 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -440,21 +440,27 @@ class DefaultKernelCreationDriver: gpu_opts = self._cfg.gpu omit_range_check: bool = gpu_opts.get_option("omit_range_check") + thread_mapping = ( + self._gpu_indexing.get_thread_mapping() + if self._gpu_indexing is not None + else None + ) + + GpuPlatform: type match self._target: - case Target.CUDA | Target.HIP: - from ..backend.platforms import CudaPlatform - - thread_mapping = ( - self._gpu_indexing.get_thread_mapping() - if self._gpu_indexing is not None - else None - ) - - return CudaPlatform( - self._ctx, - omit_range_check=omit_range_check, - thread_mapping=thread_mapping, - ) + case Target.CUDA: + from ..backend.platforms import CudaPlatform as GpuPlatform + case Target.HIP: + from ..backend.platforms import HipPlatform as GpuPlatform + case _: + assert False, f"unexpected GPU target: {self._target}" + + return GpuPlatform( + self._ctx, + omit_range_check=omit_range_check, + thread_mapping=thread_mapping, + ) + + elif self._target == Target.SYCL: from ..backend.platforms import SyclPlatform diff --git a/src/pystencils/codegen/gpu_indexing.py b/src/pystencils/codegen/gpu_indexing.py index 27d6fc817d5a9193c3faa4b170d907987fe6022e..890afb0fb58d5af4a7eaade9c8c0222c20aca604 100644 --- a/src/pystencils/codegen/gpu_indexing.py +++ b/src/pystencils/codegen/gpu_indexing.py @@ -14,7 +14,7 @@ from ..backend.kernelcreation import ( FullIterationSpace, SparseIterationSpace, ) -from ..backend.platforms.cuda import ThreadMapping +from ..backend.platforms.generic_gpu import ThreadMapping from ..backend.ast.expressions import PsExpression @@ -236,7 +236,7 @@ class GpuIndexing: def
get_thread_mapping(self) -> ThreadMapping: """Retrieve a thread mapping object for use by the backend""" - from ..backend.platforms.cuda import Linear3DMapping, Blockwise4DMapping + from ..backend.platforms.generic_gpu import Linear3DMapping, Blockwise4DMapping match self._scheme: case GpuIndexingScheme.Linear3D:
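
The index arithmetic that this change moves from CudaPlatform into the shared GenericGpu base is easiest to see with concrete numbers. The standalone Python sketch below mirrors the per-dimension expression emitted by Linear3DMapping, counter = start + step * (blockIdx * blockDim + threadIdx); the function name linear3d_point and the example launch configuration are purely illustrative and not part of the pystencils API.

# Simplified, standalone illustration of the Linear3DMapping arithmetic that this
# diff moves from cuda.py into generic_gpu.py. Plain integers stand in for the
# pystencils AST nodes; nothing here is pystencils API.
def linear3d_point(block_idx, block_dim, thread_idx, start, step):
    """Map one GPU thread onto one iteration-space point, dimension by dimension."""
    return tuple(
        st + sp * (b * d + t)
        for b, d, t, st, sp in zip(block_idx, block_dim, thread_idx, start, step)
    )

# Hypothetical launch: thread (5, 2, 0) of block (3, 1, 0), block size 16x8x4,
# iteration space starting at (1, 1, 1) with unit step.
print(linear3d_point((3, 1, 0), (16, 8, 4), (5, 2, 0), (1, 1, 1), (1, 1, 1)))
# -> (54, 11, 1)

Unless omit_range_check is set, the generated kernel additionally guards the body with a counter < stop check per dimension, which corresponds to the PsConditional constructed in _prepend_dense_translation.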