Skip to content
Snippets Groups Projects
Commit 71ab7480 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

start introduction of Target.HIP

parent 81184635
No related branches found
No related tags found
1 merge request!458HIP Target and Platform
Pipeline #75790 failed
...@@ -383,7 +383,7 @@ class BasePrinter(ABC): ...@@ -383,7 +383,7 @@ class BasePrinter(ABC):
from ...codegen import GpuKernel from ...codegen import GpuKernel
sig_parts = [self._func_prefix] if self._func_prefix is not None else [] sig_parts = [self._func_prefix] if self._func_prefix is not None else []
if isinstance(func, GpuKernel) and func.target == Target.CUDA: if isinstance(func, GpuKernel) and func.target.is_gpu():
sig_parts.append("__global__") sig_parts.append("__global__")
sig_parts += ["void", func.name, f"({params_str})"] sig_parts += ["void", func.name, f"({params_str})"]
signature = " ".join(sig_parts) signature = " ".join(sig_parts)
......
...@@ -593,12 +593,14 @@ class CreateKernelConfig(ConfigBase): ...@@ -593,12 +593,14 @@ class CreateKernelConfig(ConfigBase):
"""Returns either the user-specified JIT compiler, or infers one from the target if none is given.""" """Returns either the user-specified JIT compiler, or infers one from the target if none is given."""
jit: JitBase | None = self.get_option("jit") jit: JitBase | None = self.get_option("jit")
target = self.get_target()
if jit is None: if jit is None:
if self.get_target().is_cpu(): if target.is_cpu():
from ..jit import LegacyCpuJit from ..jit import LegacyCpuJit
return LegacyCpuJit() return LegacyCpuJit()
elif self.get_target() == Target.CUDA: elif target == Target.CUDA:
try: try:
from ..jit.gpu_cupy import CupyJit from ..jit.gpu_cupy import CupyJit
...@@ -609,7 +611,7 @@ class CreateKernelConfig(ConfigBase): ...@@ -609,7 +611,7 @@ class CreateKernelConfig(ConfigBase):
return no_jit return no_jit
elif self.get_target() == Target.SYCL: elif target == Target.SYCL or target == Target.HIP:
from ..jit import no_jit from ..jit import no_jit
return no_jit return no_jit
......
...@@ -398,7 +398,7 @@ class DefaultKernelCreationDriver: ...@@ -398,7 +398,7 @@ class DefaultKernelCreationDriver:
return kernel_ast return kernel_ast
def _get_gpu_indexing(self) -> GpuIndexing | None: def _get_gpu_indexing(self) -> GpuIndexing | None:
if self._target != Target.CUDA: if not self._target.is_gpu():
return None return None
from .gpu_indexing import dim3 from .gpu_indexing import dim3
...@@ -441,7 +441,7 @@ class DefaultKernelCreationDriver: ...@@ -441,7 +441,7 @@ class DefaultKernelCreationDriver:
omit_range_check: bool = gpu_opts.get_option("omit_range_check") omit_range_check: bool = gpu_opts.get_option("omit_range_check")
match self._target: match self._target:
case Target.CUDA: case Target.CUDA | Target.HIP:
from ..backend.platforms import CudaPlatform from ..backend.platforms import CudaPlatform
thread_mapping = ( thread_mapping = (
......
...@@ -30,6 +30,7 @@ class Target(Flag): ...@@ -30,6 +30,7 @@ class Target(Flag):
_GPU = auto() _GPU = auto()
_CUDA = auto() _CUDA = auto()
_HIP = auto()
_SYCL = auto() _SYCL = auto()
...@@ -86,6 +87,12 @@ class Target(Flag): ...@@ -86,6 +87,12 @@ class Target(Flag):
Generate a CUDA kernel for a generic Nvidia GPU. Generate a CUDA kernel for a generic Nvidia GPU.
""" """
HIP = _GPU | _HIP
"""Generic HIP GPU target.
Generate a HIP kernel for generic AMD or NVidia GPUs.
"""
GPU = CUDA GPU = CUDA
"""Alias for `Target.CUDA`, for backward compatibility.""" """Alias for `Target.CUDA`, for backward compatibility."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment