From de1dc39e761e2fd151c12e87952a897a150674d6 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 14 Mar 2025 16:47:42 +0000
Subject: [PATCH] start fixing JIT for HIP

---
 src/pystencils/codegen/config.py        |  6 ++++--
 src/pystencils/codegen/target.py        | 12 ++++++++++++
 src/pystencils/jit/gpu_cupy.py          | 10 +++++++---
 tests/fixtures.py                       |  5 ++++-
 tests/kernelcreation/test_buffer_gpu.py | 24 ++++++++++++------------
 5 files changed, 39 insertions(+), 18 deletions(-)

diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py
index 821bb7e07..a765dea2e 100644
--- a/src/pystencils/codegen/config.py
+++ b/src/pystencils/codegen/config.py
@@ -586,6 +586,8 @@ class CreateKernelConfig(ConfigBase):
         match t:
             case Target.CurrentCPU:
                 return Target.auto_cpu()
+            case Target.CurrentGPU:
+                return Target.auto_gpu()
             case _:
                 return t
 
@@ -600,7 +602,7 @@ class CreateKernelConfig(ConfigBase):
                 from ..jit import LegacyCpuJit
 
                 return LegacyCpuJit()
-            elif target == Target.CUDA:
+            elif target == Target.CUDA or target == Target.HIP:
                 try:
                     from ..jit.gpu_cupy import CupyJit
                     
@@ -611,7 +613,7 @@ class CreateKernelConfig(ConfigBase):
 
                     return no_jit
 
-            elif target == Target.SYCL or target == Target.HIP:
+            elif target == Target.SYCL:
                 from ..jit import no_jit
 
                 return no_jit
diff --git a/src/pystencils/codegen/target.py b/src/pystencils/codegen/target.py
index 03364af28..c4b08b95c 100644
--- a/src/pystencils/codegen/target.py
+++ b/src/pystencils/codegen/target.py
@@ -126,6 +126,18 @@ class Target(Flag):
         else:
             return Target.GenericCPU
         
+    @staticmethod
+    def auto_gpu() -> Target:
+        """Pick HIP or CUDA from cupy's runtime; raise RuntimeError if cupy is missing."""
+        try:
+            import cupy
+        except ImportError as err:
+            raise RuntimeError(
+                "Cannot infer GPU target since cupy is not installed."
+            ) from err
+        # Only the import is guarded, so unrelated failures are not reported as "not installed".
+        return Target.HIP if cupy.cuda.runtime.is_hip else Target.CUDA
+        
     @staticmethod
     def available_targets() -> list[Target]:
         targets = [Target.GenericCPU]
diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py
index d45abf878..1461bfd7a 100644
--- a/src/pystencils/jit/gpu_cupy.py
+++ b/src/pystencils/jit/gpu_cupy.py
@@ -208,7 +208,11 @@ class CupyKernelWrapper(KernelWrapper):
 class CupyJit(JitBase):
 
     def __init__(self, default_block_size: Sequence[int] = (128, 2, 1)):
-        self._runtime_headers = {"<cstdint>"}
+        self._runtime_headers: set[str]
+        if cp.cuda.runtime.is_hip:
+            self._runtime_headers = set()
+        else:
+            self._runtime_headers = {"<cstdint>"}
 
         if len(default_block_size) > 3:
             raise ValueError(
@@ -226,9 +230,9 @@ class CupyJit(JitBase):
                 "`cupy` is not installed: just-in-time-compilation of CUDA kernels is unavailable."
             )
 
-        if not isinstance(kernel, GpuKernel) or kernel.target != Target.CUDA:
+        if not isinstance(kernel, GpuKernel):
             raise ValueError(
-                "The CupyJit just-in-time compiler only accepts kernels generated for CUDA or HIP"
+                "The CupyJit just-in-time compiler only accepts GPU kernels generated for CUDA or HIP"
             )
 
         options = self._compiler_options()
diff --git a/tests/fixtures.py b/tests/fixtures.py
index ba2593f76..a19519988 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -23,7 +23,10 @@ AVAILABLE_TARGETS = [ps.Target.GenericCPU]
 try:
     import cupy
 
-    AVAILABLE_TARGETS += [ps.Target.CUDA]
+    # Register the single GPU target selected by cupy's `is_hip` flag.
+    AVAILABLE_TARGETS += [
+        ps.Target.HIP if cupy.cuda.runtime.is_hip else ps.Target.CUDA
+    ]
 except ImportError:
     pass
 
diff --git a/tests/kernelcreation/test_buffer_gpu.py b/tests/kernelcreation/test_buffer_gpu.py
index 0b5019fba..bd9d2156b 100644
--- a/tests/kernelcreation/test_buffer_gpu.py
+++ b/tests/kernelcreation/test_buffer_gpu.py
@@ -58,7 +58,7 @@ def test_full_scalar_field():
 
         pack_eqs = [Assignment(buffer.center(), src_field.center())]
 
-        config = CreateKernelConfig(target=pystencils.Target.GPU)
+        config = CreateKernelConfig(target=pystencils.Target.CurrentGPU)
         pack_ast = create_kernel(pack_eqs, config=config)
 
         pack_kernel = pack_ast.compile()
@@ -66,7 +66,7 @@ def test_full_scalar_field():
 
         unpack_eqs = [Assignment(dst_field.center(), buffer.center())]
 
-        config = CreateKernelConfig(target=pystencils.Target.GPU)
+        config = CreateKernelConfig(target=pystencils.Target.CurrentGPU)
         unpack_ast = create_kernel(unpack_eqs, config=config)
 
         unpack_kernel = unpack_ast.compile()
@@ -94,7 +94,7 @@ def test_field_slice():
 
             pack_eqs = [Assignment(buffer.center(), src_field.center())]
 
-            config = CreateKernelConfig(target=pystencils.Target.GPU)
+            config = CreateKernelConfig(target=pystencils.Target.CurrentGPU)
             pack_ast = create_kernel(pack_eqs, config=config)
 
             pack_kernel = pack_ast.compile()
@@ -103,7 +103,7 @@ def test_field_slice():
             # Unpack into ghost layer of dst_field in N direction
             unpack_eqs = [Assignment(dst_field.center(), buffer.center())]
 
-            config = CreateKernelConfig(target=pystencils.Target.GPU)
+            config = CreateKernelConfig(target=pystencils.Target.CurrentGPU)
             unpack_ast = create_kernel(unpack_eqs, config=config)
 
             unpack_kernel = unpack_ast.compile()
@@ -131,7 +131,7 @@ def test_all_cell_values():
             eq = Assignment(buffer(idx), src_field(idx))
             pack_eqs.append(eq)
 
-        config = CreateKernelConfig(target=pystencils.Target.GPU)
+        config = CreateKernelConfig(target=pystencils.Target.CurrentGPU)
         pack_code = create_kernel(pack_eqs, config=config)
         pack_kernel = pack_code.compile()
 
@@ -143,7 +143,7 @@ def test_all_cell_values():
             eq = Assignment(dst_field(idx), buffer(idx))
             unpack_eqs.append(eq)
 
-        config = CreateKernelConfig(target=pystencils.Target.GPU)
+        config = CreateKernelConfig(target=pystencils.Target.CurrentGPU)
         unpack_ast = create_kernel(unpack_eqs, config=config)
         unpack_kernel = unpack_ast.compile()
         unpack_kernel(buffer=gpu_buffer_arr, dst_field=gpu_dst_arr)
@@ -173,7 +173,7 @@ def test_subset_cell_values():
             pack_eqs.append(eq)
 
         pack_types = {'src_field': gpu_src_arr.dtype, 'buffer': gpu_buffer_arr.dtype}
-        config = CreateKernelConfig(target=pystencils.Target.GPU)
+        config = CreateKernelConfig(target=pystencils.Target.CurrentGPU)
         pack_ast = create_kernel(pack_eqs, config=config)
         pack_kernel = pack_ast.compile()
         pack_kernel(buffer=gpu_buffer_arr, src_field=gpu_src_arr)
@@ -185,7 +185,7 @@ def test_subset_cell_values():
             unpack_eqs.append(eq)
 
         unpack_types = {'dst_field': gpu_dst_arr.dtype, 'buffer': gpu_buffer_arr.dtype}
-        config = CreateKernelConfig(target=pystencils.Target.GPU)
+        config = CreateKernelConfig(target=pystencils.Target.CurrentGPU)
         unpack_ast = create_kernel(unpack_eqs, config=config)
         unpack_kernel = unpack_ast.compile()
 
@@ -215,7 +215,7 @@ def test_field_layouts():
                 pack_eqs.append(eq)
 
             pack_types = {'src_field': gpu_src_arr.dtype, 'buffer': gpu_buffer_arr.dtype}
-            config = CreateKernelConfig(target=pystencils.Target.GPU)
+            config = CreateKernelConfig(target=pystencils.Target.CurrentGPU)
             pack_ast = create_kernel(pack_eqs, config=config)
             pack_kernel = pack_ast.compile()
 
@@ -228,7 +228,7 @@ def test_field_layouts():
                 unpack_eqs.append(eq)
 
             unpack_types = {'dst_field': gpu_dst_arr.dtype, 'buffer': gpu_buffer_arr.dtype}
-            config = CreateKernelConfig(target=pystencils.Target.GPU)
+            config = CreateKernelConfig(target=pystencils.Target.CurrentGPU)
             unpack_ast = create_kernel(unpack_eqs, config=config)
             unpack_kernel = unpack_ast.compile()
 
@@ -299,7 +299,7 @@ def test_iteration_slices(gpu_indexing):
         gpu_src_arr.set(src_arr)
         gpu_dst_arr.fill(0)
 
-        config = CreateKernelConfig(target=Target.GPU, iteration_slice=pack_slice)
+        config = CreateKernelConfig(target=Target.CurrentGPU, iteration_slice=pack_slice)
 
         pack_code = create_kernel(pack_eqs, config=config)
         pack_kernel = pack_code.compile()
@@ -311,7 +311,7 @@ def test_iteration_slices(gpu_indexing):
             eq = Assignment(dst_field(idx), buffer(idx))
             unpack_eqs.append(eq)
 
-        config = CreateKernelConfig(target=Target.GPU, iteration_slice=pack_slice)
+        config = CreateKernelConfig(target=Target.CurrentGPU, iteration_slice=pack_slice)
 
         unpack_code = create_kernel(unpack_eqs, config=config)
         unpack_kernel = unpack_code.compile()
-- 
GitLab