From bd47d36951f246b03b2d43ff24354763d4adf041 Mon Sep 17 00:00:00 2001 From: Michael Kuron <m.kuron@gmx.de> Date: Wed, 12 Jul 2023 19:42:11 +0200 Subject: [PATCH] Address review comments --- pystencils/datahandling/__init__.py | 9 +++++++-- pystencils/datahandling/serial_datahandling.py | 2 +- pystencils/gpu/gpujit.py | 4 ++-- pystencils_tests/test_datahandling.py | 2 +- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pystencils/datahandling/__init__.py b/pystencils/datahandling/__init__.py index 7f142428c..18053d2d9 100644 --- a/pystencils/datahandling/__init__.py +++ b/pystencils/datahandling/__init__.py @@ -23,7 +23,8 @@ def create_data_handling(domain_size: Tuple[int, ...], default_layout: str = 'SoA', default_target: Target = Target.CPU, parallel: bool = False, - default_ghost_layers: int = 1) -> DataHandling: + default_ghost_layers: int = 1, + device_number: Union[int, None] = None) -> DataHandling: """Creates a data handling instance. Args: @@ -34,6 +35,9 @@ def create_data_handling(domain_size: Tuple[int, ...], default_target: `Target` parallel: if True a parallel domain is created using walberla - each MPI process gets a part of the domain default_ghost_layers: default number of ghost layers if not overwritten in 'add_array' + device_number: If `default_target` is set to 'GPU' and `parallel` is False, a device number should be + specified. If none is given, the device with the largest amount of memory is used. If multiple + devices have the same amount of memory, the one with the lower number is used """ if isinstance(default_target, str): new_target = Target[default_target.upper()] @@ -69,7 +73,8 @@ def create_data_handling(domain_size: Tuple[int, ...], periodicity=periodicity, default_target=default_target, default_layout=default_layout, - default_ghost_layers=default_ghost_layers) + default_ghost_layers=default_ghost_layers, + device_number=device_number) __all__ = ['create_data_handling'] diff --git a/pystencils/datahandling/serial_datahandling.py b/pystencils/datahandling/serial_datahandling.py index e0b42771d..0f5ddb431 100644 --- a/pystencils/datahandling/serial_datahandling.py +++ b/pystencils/datahandling/serial_datahandling.py @@ -57,7 +57,7 @@ class SerialDataHandling(DataHandling): if not array_handler: try: if device_number is None: - import cupy + import cupy.cuda.runtime if cupy.cuda.runtime.getDeviceCount() > 0: device_number = sorted(range(cupy.cuda.runtime.getDeviceCount()), key=lambda i: cupy.cuda.Device(i).mem_info[1], reverse=True)[0] diff --git a/pystencils/gpu/gpujit.py b/pystencils/gpu/gpujit.py index 420c3241d..522689241 100644 --- a/pystencils/gpu/gpujit.py +++ b/pystencils/gpu/gpujit.py @@ -75,7 +75,7 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen try: args, block_and_thread_numbers = cache[key] device = set(a.device.id for a in args if type(a) is cp.ndarray) - assert len(device) == 1 + assert len(device) == 1, "All arrays used by a kernel need to be allocated on the same device" with cp.cuda.Device(device.pop()): func(block_and_thread_numbers['grid'], block_and_thread_numbers['block'], args) except KeyError: @@ -92,7 +92,7 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen cache[key] = (args, block_and_thread_numbers) cache_values.append(kwargs) # keep objects alive such that ids remain unique device = set(a.device.id for a in args if type(a) is cp.ndarray) - assert len(device) == 1 + assert len(device) == 1, "All arrays used by a kernel need to be allocated on the same device" with cp.cuda.Device(device.pop()): func(block_and_thread_numbers['grid'], block_and_thread_numbers['block'], args) # useful for debugging: diff --git a/pystencils_tests/test_datahandling.py b/pystencils_tests/test_datahandling.py index d31ae1bed..15e9cd74b 100644 --- a/pystencils_tests/test_datahandling.py +++ b/pystencils_tests/test_datahandling.py @@ -16,7 +16,7 @@ except ImportError: pytest = unittest.mock.MagicMock() try: - import cupy + import cupy.cuda.runtime device_numbers = range(cupy.cuda.runtime.getDeviceCount()) except ImportError: device_numbers = [] -- GitLab