From 7db7f592e9691ced3052fa6301f45c89afa54165 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sat, 15 Mar 2025 11:50:17 +0100
Subject: [PATCH] do not use cupy in code that must always be executable

---
 src/pystencils/jit/gpu_cupy.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py
index 69e965325..7065ce815 100644
--- a/src/pystencils/jit/gpu_cupy.py
+++ b/src/pystencils/jit/gpu_cupy.py
@@ -208,12 +208,6 @@ class CupyKernelWrapper(KernelWrapper):
 class CupyJit(JitBase):
 
     def __init__(self, default_block_size: Sequence[int] = (128, 2, 1)):
-        self._runtime_headers: set[str]
-        if cp.cuda.runtime.is_hip:
-            self._runtime_headers = set()
-        else:
-            self._runtime_headers = {"<cstdint>"}
-
         if len(default_block_size) > 3:
             raise ValueError(
                 f"Invalid block size: {default_block_size}. Must be at most three-dimensional."
@@ -234,12 +228,12 @@ class CupyJit(JitBase):
             raise JitError(
                 "The CupyJit just-in-time compiler only accepts GPU kernels generated for CUDA or HIP"
             )
-        
+
         if kernel.target == Target.CUDA and cp.cuda.runtime.is_hip:
             raise JitError(
                 "Cannot compile a CUDA kernel on a HIP-based Cupy installation."
             )
-        
+
         if kernel.target == Target.HIP and not cp.cuda.runtime.is_hip:
             raise JitError(
                 "Cannot compile a HIP kernel on a CUDA-based Cupy installation."
@@ -261,7 +255,13 @@ class CupyJit(JitBase):
         return tuple(options)
 
     def _prelude(self, kfunc: GpuKernel) -> str:
-        headers = self._runtime_headers
+
+        headers: set[str]
+        if cp.cuda.runtime.is_hip:
+            headers = set()
+        else:
+            headers = {"<cstdint>"}
+
         headers |= kfunc.required_headers
 
         if '"pystencils_runtime/half.h"' in headers:
-- 
GitLab