From bd47d36951f246b03b2d43ff24354763d4adf041 Mon Sep 17 00:00:00 2001
From: Michael Kuron <m.kuron@gmx.de>
Date: Wed, 12 Jul 2023 19:42:11 +0200
Subject: [PATCH] Address review comments

---
 pystencils/datahandling/__init__.py            | 9 +++++++--
 pystencils/datahandling/serial_datahandling.py | 2 +-
 pystencils/gpu/gpujit.py                       | 4 ++--
 pystencils_tests/test_datahandling.py          | 2 +-
 4 files changed, 11 insertions(+), 6 deletions(-)

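A quick usage sketch of the new parameter for reviewers (assumes a CUDA-enabled
cupy installation with at least two visible devices; the domain size and device
index below are placeholders). The function and keyword arguments are the ones
added in the hunks that follow:

    from pystencils.datahandling import create_data_handling

    # Pin all GPU arrays created by this data handling to device 1.
    dh = create_data_handling(domain_size=(32, 32),
                              default_target='GPU',
                              parallel=False,
                              device_number=1)

    # With device_number left as None, the device with the largest amount of
    # memory is selected automatically (lowest device number on ties).
    dh_auto = create_data_handling(domain_size=(32, 32),
                                   default_target='GPU',
                                   parallel=False)
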
diff --git a/pystencils/datahandling/__init__.py b/pystencils/datahandling/__init__.py
index 7f142428c..18053d2d9 100644
--- a/pystencils/datahandling/__init__.py
+++ b/pystencils/datahandling/__init__.py
@@ -23,7 +23,8 @@ def create_data_handling(domain_size: Tuple[int, ...],
                          default_layout: str = 'SoA',
                          default_target: Target = Target.CPU,
                          parallel: bool = False,
-                         default_ghost_layers: int = 1) -> DataHandling:
+                         default_ghost_layers: int = 1,
+                         device_number: Union[int, None] = None) -> DataHandling:
     """Creates a data handling instance.
 
     Args:
@@ -34,6 +35,9 @@ def create_data_handling(domain_size: Tuple[int, ...],
         default_target: `Target`
         parallel: if True a parallel domain is created using walberla - each MPI process gets a part of the domain
         default_ghost_layers: default number of ghost layers if not overwritten in 'add_array'
+        device_number: If `default_target` is set to 'GPU' and `parallel` is False, a device number can be
+                       specified. If none is given, the device with the largest amount of memory is used. If multiple
+                       devices have the same amount of memory, the one with the lowest device number is used.
     """
     if isinstance(default_target, str):
         new_target = Target[default_target.upper()]
@@ -69,7 +73,8 @@ def create_data_handling(domain_size: Tuple[int, ...],
                                   periodicity=periodicity,
                                   default_target=default_target,
                                   default_layout=default_layout,
-                                  default_ghost_layers=default_ghost_layers)
+                                  default_ghost_layers=default_ghost_layers,
+                                  device_number=device_number)
 
 
 __all__ = ['create_data_handling']
diff --git a/pystencils/datahandling/serial_datahandling.py b/pystencils/datahandling/serial_datahandling.py
index e0b42771d..0f5ddb431 100644
--- a/pystencils/datahandling/serial_datahandling.py
+++ b/pystencils/datahandling/serial_datahandling.py
@@ -57,7 +57,7 @@ class SerialDataHandling(DataHandling):
         if not array_handler:
             try:
                 if device_number is None:
-                    import cupy
+                    import cupy.cuda.runtime
                     if cupy.cuda.runtime.getDeviceCount() > 0:
                         device_number = sorted(range(cupy.cuda.runtime.getDeviceCount()),
                                                key=lambda i: cupy.cuda.Device(i).mem_info[1], reverse=True)[0]
diff --git a/pystencils/gpu/gpujit.py b/pystencils/gpu/gpujit.py
index 420c3241d..522689241 100644
--- a/pystencils/gpu/gpujit.py
+++ b/pystencils/gpu/gpujit.py
@@ -75,7 +75,7 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
         try:
             args, block_and_thread_numbers = cache[key]
             device = set(a.device.id for a in args if type(a) is cp.ndarray)
-            assert len(device) == 1
+            assert len(device) == 1, "All arrays used by a kernel need to be allocated on the same device"
             with cp.cuda.Device(device.pop()):
                 func(block_and_thread_numbers['grid'], block_and_thread_numbers['block'], args)
         except KeyError:
@@ -92,7 +92,7 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
             cache[key] = (args, block_and_thread_numbers)
             cache_values.append(kwargs)  # keep objects alive such that ids remain unique
             device = set(a.device.id for a in args if type(a) is cp.ndarray)
-            assert len(device) == 1
+            assert len(device) == 1, "All arrays used by a kernel need to be allocated on the same device"
             with cp.cuda.Device(device.pop()):
                 func(block_and_thread_numbers['grid'], block_and_thread_numbers['block'], args)
                 # useful for debugging:
diff --git a/pystencils_tests/test_datahandling.py b/pystencils_tests/test_datahandling.py
index d31ae1bed..15e9cd74b 100644
--- a/pystencils_tests/test_datahandling.py
+++ b/pystencils_tests/test_datahandling.py
@@ -16,7 +16,7 @@ except ImportError:
     pytest = unittest.mock.MagicMock()
 
 try:
-    import cupy
+    import cupy.cuda.runtime
     device_numbers = range(cupy.cuda.runtime.getDeviceCount())
 except ImportError:
     device_numbers = []
-- 
GitLab
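
For context on the extended assert messages in gpujit.py, a small cupy sketch
(assuming two visible CUDA devices) of how arrays end up on different devices
and what the launcher checks before picking a device for the launch:

    import cupy as cp

    with cp.cuda.Device(0):
        a = cp.zeros(16)   # allocated on device 0
    with cp.cuda.Device(1):
        b = cp.zeros(16)   # allocated on device 1

    # Same check as in make_python_function: collect the device id of every
    # array the kernel touches.  Here the set is {0, 1}, so the launcher would
    # refuse to run; all arrays must live on the device that the surrounding
    # cp.cuda.Device(...) context selects for the launch.
    device = set(x.device.id for x in (a, b))
    print(device)   # {0, 1}: exactly the situation the new assert message explains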