From 9e3e8af0a3dc53e056155c0f01f1d5d260f3622c Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 13 Jan 2025 19:28:37 +0100
Subject: [PATCH] fix cupy package name

---
 noxfile.py | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/noxfile.py b/noxfile.py
index db8398b00..aa9ad3043 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -4,26 +4,21 @@ from typing import Sequence
 import os
 import nox
 import subprocess
-import re
 
 nox.options.sessions = ["lint", "typecheck", "testsuite"]
 
 
 def get_cuda_version() -> None | tuple[int, ...]:
-    query_args = ["nvcc", "--version"]
-    
+    smi_args = ["nvidia-smi", "--version"]
+
     try:
-        query_result = subprocess.run(query_args, capture_output=True)
+        result = subprocess.run(smi_args, capture_output=True)
     except FileNotFoundError:
         return None
 
-    matches = re.findall(r"release \d+\.\d+", str(query_result.stdout))
-    if matches:
-        match = matches[0]
-        version_string = match.split()[-1]
-        return tuple(int(v) for v in version_string.split("."))
-    else:
-        return None
+    smi_output = str(result.stdout).splitlines()
+    cuda_version = smi_output[-1].split(":")[1].strip()
+    return tuple(int(v) for v in cuda_version.split("."))
 
 
 def editable_install(session: nox.Session, opts: Sequence[str] = ()):
@@ -55,13 +50,13 @@ def typecheck(session: nox.Session):
 def testsuite(session: nox.Session, cupy_version: str | None):
     if cupy_version is not None:
         cuda_version = get_cuda_version()
-        if cuda_version is None or cuda_version[0] < 11:
+        if cuda_version is None or cuda_version[0] not in (11, 12):
             session.skip(
-                "No compatible installation of CUDA found - Need at least CUDA 11"
+                "No compatible installation of CUDA found - Need either CUDA 11 or 12"
             )
 
         cuda_major = cuda_version[0]
-        cupy_package = f"cupy-cuda{cuda_major}=={cupy_version}"
+        cupy_package = f"cupy-cuda{cuda_major}x=={cupy_version}"
         session.install(cupy_package)
 
     editable_install(session, ["alltrafos", "use_cython", "interactive", "testsuite"])
-- 
GitLab