From 65a865b7209be505301ce1cba23c7bb758590d71 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 14 Jan 2025 09:50:38 +0100
Subject: [PATCH] Install cupy in docs session

---
 noxfile.py | 55 ++++++++++++++++++++++++++++++++----------------------
 pytest.ini |  2 ++
 2 files changed, 35 insertions(+), 22 deletions(-)

diff --git a/noxfile.py b/noxfile.py
index d54815c43..121f4f575 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -9,9 +9,9 @@ import re
 nox.options.sessions = ["lint", "typecheck", "testsuite"]
 
 
-def get_cuda_version() -> None | tuple[int, ...]:
+def get_cuda_version(session: nox.Session) -> None | tuple[int, ...]:
     query_args = ["nvcc", "--version"]
-    
+
     try:
         query_result = subprocess.run(query_args, capture_output=True)
     except FileNotFoundError:
@@ -21,9 +21,34 @@ def get_cuda_version() -> None | tuple[int, ...]:
     if matches:
         match = matches[0]
         version_string = match.split()[-1]
-        return tuple(int(v) for v in version_string.split("."))
-    else:
-        return None
+        try:
+            return tuple(int(v) for v in version_string.split("."))
+        except ValueError:
+            pass
+    
+    session.warn("nvcc was found, but I am unable to determine the CUDA version.")
+    return None
+
+
+def install_cupy(
+    session: nox.Session, cupy_version: str, skip_if_no_cuda: bool = False
+):
+    if cupy_version is not None:
+        cuda_version = get_cuda_version(session)
+        if cuda_version is None or cuda_version[0] not in (11, 12):
+            if skip_if_no_cuda:
+                session.skip(
+                    "No compatible installation of CUDA found - Need either CUDA 11 or 12"
+                )
+            else:
+                session.warn(
+                    "Running without cupy: no compatbile installation of CUDA found. Need either CUDA 11 or 12."
+                )
+                return
+
+        cuda_major = cuda_version[0]
+        cupy_package = f"cupy-cuda{cuda_major}x=={cupy_version}"
+        session.install(cupy_package)
 
 
 def editable_install(session: nox.Session, opts: Sequence[str] = ()):
@@ -54,15 +79,7 @@ def typecheck(session: nox.Session):
 @nox.session(python="3.10", tags=["test"])
 def testsuite(session: nox.Session, cupy_version: str | None):
     if cupy_version is not None:
-        cuda_version = get_cuda_version()
-        if cuda_version is None or cuda_version[0] not in (11, 12):
-            session.skip(
-                "No compatible installation of CUDA found - Need either CUDA 11 or 12"
-            )
-
-        cuda_major = cuda_version[0]
-        cupy_package = f"cupy-cuda{cuda_major}x=={cupy_version}"
-        session.install(cupy_package)
+        install_cupy(session, cupy_version, skip_if_no_cuda=True)
 
     editable_install(session, ["alltrafos", "use_cython", "interactive", "testsuite"])
 
@@ -88,13 +105,7 @@ def testsuite(session: nox.Session, cupy_version: str | None):
 @nox.session(python=["3.10"], tags=["docs"])
 def docs(session: nox.Session):
     """Build the documentation pages"""
+    install_cupy(session, "12.3")
     editable_install(session, ["doc"])
     session.chdir("docs")
-    session.run(
-        "make",
-        "html",
-        external=True,
-        env={
-            "SPHINXOPTS": "-W --keep-going"
-        }
-    )
+    session.run("make", "html", external=True, env={"SPHINXOPTS": "-W --keep-going"})
diff --git a/pytest.ini b/pytest.ini
index 281eaa21e..707a43b45 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -23,6 +23,7 @@ filterwarnings =
        ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc':DeprecationWarning
        ignore:Animation was deleted without rendering anything:UserWarning
 
+# Coverage Configuration
 [run]
 branch = True
 source = src/pystencils
@@ -31,6 +32,7 @@ source = src/pystencils
 omit = doc/*
        tests/*
        setup.py
+       noxfile.py
        quicktest.py
        conftest.py
        versioneer.py
-- 
GitLab