From 4f1d2c94c2146d9a04f8a6c7a446653bc123543f Mon Sep 17 00:00:00 2001
From: Michael Kuron <mkuron@icp.uni-stuttgart.de>
Date: Sat, 8 Jul 2023 08:11:28 +0200
Subject: [PATCH] Remove pystencils.GPU_DEVICE

---
 lbmpy/max_domain_size_info.py               | 20 +++++---------------
 lbmpy_tests/test_gpu_block_size_limiting.py |  2 +-
 2 files changed, 6 insertions(+), 16 deletions(-)

diff --git a/lbmpy/max_domain_size_info.py b/lbmpy/max_domain_size_info.py
index 3e6b9615..23490b91 100644
--- a/lbmpy/max_domain_size_info.py
+++ b/lbmpy/max_domain_size_info.py
@@ -26,7 +26,6 @@ Examples:
 import warnings
 
 import numpy as np
-import pystencils
 
 # Optional packages cpuinfo, cupy and psutil for hardware queries
 try:
@@ -36,19 +35,8 @@ except ImportError:
 
 try:
     import cupy
-    device = cupy.cuda.Device(pystencils.GPU_DEVICE)
 except ImportError:
     cupy = None
-    device = None
-
-if cupy:
-    try:
-        device = cupy.cuda.Device(pystencils.GPU_DEVICE)
-    except AttributeError:
-        warnings.warn("You are using an old pystencils version. Consider updating it")
-        device = cupy.cuda.Device(0)
-else:
-    device = None
 
 try:
     from psutil import virtual_memory
@@ -125,9 +113,11 @@ def memory_sizes_of_current_machine():
         if 'l3_cache_size' in cpu_info:
             result['L3'] = cpu_info['l3_cache_size']
 
-    if device:
-        size = device.mem_info[1] / (1024 * 1024)
-        result['GPU'] = "{0:.0f} MB".format(size)
+    if cupy:
+        for i in range(cupy.cuda.runtime.getDeviceCount()):
+            device = cupy.cuda.Device(i)
+            size = device.mem_info[1] / (1024 * 1024)
+            result[f'GPU{i}'] = "{0:.0f} MB".format(size)
 
     if virtual_memory:
         mem = virtual_memory()
diff --git a/lbmpy_tests/test_gpu_block_size_limiting.py b/lbmpy_tests/test_gpu_block_size_limiting.py
index f3bfc805..2f8964bb 100644
--- a/lbmpy_tests/test_gpu_block_size_limiting.py
+++ b/lbmpy_tests/test_gpu_block_size_limiting.py
@@ -17,5 +17,5 @@ def test_gpu_block_size_limiting():
     kernel = ast.compile()
     assert all(b < too_large for b in limited_block_size['block'])
     bs = [too_large, too_large, too_large]
-    ast.indexing.limit_block_size_by_register_restriction(bs, kernel.num_regs)
+    bs = ast.indexing.limit_block_size_by_register_restriction(bs, kernel.num_regs)
     assert all(b < too_large for b in bs)
-- 
GitLab