From 9e3e8af0a3dc53e056155c0f01f1d5d260f3622c Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 13 Jan 2025 19:28:37 +0100 Subject: [PATCH] fix cupy package name --- noxfile.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/noxfile.py b/noxfile.py index db8398b00..aa9ad3043 100644 --- a/noxfile.py +++ b/noxfile.py @@ -4,26 +4,21 @@ from typing import Sequence import os import nox import subprocess -import re nox.options.sessions = ["lint", "typecheck", "testsuite"] def get_cuda_version() -> None | tuple[int, ...]: - query_args = ["nvcc", "--version"] - + smi_args = ["nvidia-smi", "--version"] + try: - query_result = subprocess.run(query_args, capture_output=True) + result = subprocess.run(smi_args, capture_output=True) except FileNotFoundError: return None - matches = re.findall(r"release \d+\.\d+", str(query_result.stdout)) - if matches: - match = matches[0] - version_string = match.split()[-1] - return tuple(int(v) for v in version_string.split(".")) - else: - return None + smi_output = str(result.stdout).splitlines() + cuda_version = smi_output[-1].split(":")[1].strip() + return tuple(int(v) for v in cuda_version.split(".")) def editable_install(session: nox.Session, opts: Sequence[str] = ()): @@ -55,13 +50,13 @@ def typecheck(session: nox.Session): def testsuite(session: nox.Session, cupy_version: str | None): if cupy_version is not None: cuda_version = get_cuda_version() - if cuda_version is None or cuda_version[0] < 11: + if cuda_version is None or cuda_version[0] not in (11, 12): session.skip( - "No compatible installation of CUDA found - Need at least CUDA 11" + "No compatible installation of CUDA found - Need either CUDA 11 or 12" ) cuda_major = cuda_version[0] - cupy_package = f"cupy-cuda{cuda_major}=={cupy_version}" + cupy_package = f"cupy-cuda{cuda_major}x=={cupy_version}" session.install(cupy_package) editable_install(session, ["alltrafos", "use_cython", "interactive", "testsuite"]) -- GitLab