diff --git a/src/pystencils/backend/emission/base_printer.py b/src/pystencils/backend/emission/base_printer.py index cc4b50e217d68a4bba28c2c95705464a91182212..c76a347f08d682db3097ad365b68c86c0330c042 100644 --- a/src/pystencils/backend/emission/base_printer.py +++ b/src/pystencils/backend/emission/base_printer.py @@ -383,7 +383,7 @@ class BasePrinter(ABC): from ...codegen import GpuKernel sig_parts = [self._func_prefix] if self._func_prefix is not None else [] - if isinstance(func, GpuKernel) and func.target == Target.CUDA: + if isinstance(func, GpuKernel) and func.target.is_gpu(): sig_parts.append("__global__") sig_parts += ["void", func.name, f"({params_str})"] signature = " ".join(sig_parts) diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py index 0d43b40e35b9e90b1f13e2ba7bc279274f466859..821bb7e075ed789f34782d92a4d50ddca973d5ef 100644 --- a/src/pystencils/codegen/config.py +++ b/src/pystencils/codegen/config.py @@ -593,12 +593,14 @@ class CreateKernelConfig(ConfigBase): """Returns either the user-specified JIT compiler, or infers one from the target if none is given.""" jit: JitBase | None = self.get_option("jit") + target = self.get_target() + if jit is None: - if self.get_target().is_cpu(): + if target.is_cpu(): from ..jit import LegacyCpuJit return LegacyCpuJit() - elif self.get_target() == Target.CUDA: + elif target == Target.CUDA: try: from ..jit.gpu_cupy import CupyJit @@ -609,7 +611,7 @@ class CreateKernelConfig(ConfigBase): return no_jit - elif self.get_target() == Target.SYCL: + elif target == Target.SYCL or target == Target.HIP: from ..jit import no_jit return no_jit diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index b8f9c71015765e638ddca22278dfa15c0e5bcaa1..f53f1b9b8abb3f13c3a0a8aa07a73c06e81c72db 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -398,7 +398,7 @@ class DefaultKernelCreationDriver: return kernel_ast def _get_gpu_indexing(self) -> GpuIndexing | None: - if self._target != Target.CUDA: + if not self._target.is_gpu(): return None from .gpu_indexing import dim3 @@ -441,7 +441,7 @@ class DefaultKernelCreationDriver: omit_range_check: bool = gpu_opts.get_option("omit_range_check") match self._target: - case Target.CUDA: + case Target.CUDA | Target.HIP: from ..backend.platforms import CudaPlatform thread_mapping = ( diff --git a/src/pystencils/codegen/target.py b/src/pystencils/codegen/target.py index 0d724b87730f0ec327772bccbb55a8bfff7c8ddd..03364af28951736a02fb105f578985799543d70b 100644 --- a/src/pystencils/codegen/target.py +++ b/src/pystencils/codegen/target.py @@ -30,6 +30,7 @@ class Target(Flag): _GPU = auto() _CUDA = auto() + _HIP = auto() _SYCL = auto() @@ -86,6 +87,12 @@ class Target(Flag): Generate a CUDA kernel for a generic Nvidia GPU. """ + HIP = _GPU | _HIP + """Generic HIP GPU target. + + Generate a HIP kernel for generic AMD or NVidia GPUs. + """ + GPU = CUDA """Alias for `Target.CUDA`, for backward compatibility."""