From a01997e443d504e40d1d2f882c593b37a3c7a2cd Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 14 Mar 2025 17:46:52 +0000
Subject: [PATCH] fix remaining test suite

---
 src/pystencils/boundaries/boundaryhandling.py |  8 +++----
 src/pystencils/codegen/target.py              | 21 ++++++++++++++++---
 .../datahandling/datahandling_interface.py    |  2 --
 .../datahandling/serial_datahandling.py       |  6 +++---
 src/pystencils/jit/gpu_cupy.py                | 12 ++++++++++-
 tests/fixtures.py                             | 18 +++-------------
 tests/kernelcreation/test_functions.py        | 14 ++++++-------
 tests/kernelcreation/test_gpu.py              | 10 ++++-----
 tests/kernelcreation/test_half_precision.py   |  4 ++--
 tests/kernelcreation/test_index_kernels.py    | 17 +++++----------
 tests/kernelcreation/test_iteration_slices.py |  4 ++--
 tests/runtime/test_boundary.py                |  4 ++--
 tests/runtime/test_datahandling.py            | 12 +++++------
 13 files changed, 67 insertions(+), 65 deletions(-)

diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py
index 1f6e3d126..58340c3e0 100644
--- a/src/pystencils/boundaries/boundaryhandling.py
+++ b/src/pystencils/boundaries/boundaryhandling.py
@@ -123,7 +123,7 @@ class BoundaryHandling:
         class_ = self.IndexFieldBlockData
         class_.to_cpu = to_cpu
         class_.to_gpu = to_gpu
-        gpu = self._target in data_handling._GPU_LIKE_TARGETS
+        gpu = self._target.is_gpu()
         data_handling.add_custom_class(self._index_array_name, class_, cpu=True, gpu=gpu)
 
     @property
@@ -240,7 +240,7 @@ class BoundaryHandling:
         if self._dirty:
             self.prepare()
 
-        for b in self._data_handling.iterate(gpu=self._target in self._data_handling._GPU_LIKE_TARGETS):
+        for b in self._data_handling.iterate(gpu=self._target.is_gpu()):
             for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items():
                 kwargs[self._field_name] = b[self._field_name]
                 kwargs['indexField'] = idx_arr
@@ -255,7 +255,7 @@ class BoundaryHandling:
         if self._dirty:
             self.prepare()
 
-        for b in self._data_handling.iterate(gpu=self._target in self._data_handling._GPU_LIKE_TARGETS):
+        for b in self._data_handling.iterate(gpu=self._target.is_gpu()):
             for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items():
                 arguments = kwargs.copy()
                 arguments[self._field_name] = b[self._field_name]
@@ -341,7 +341,7 @@ class BoundaryHandling:
     def _boundary_data_initialization(self, boundary_obj, boundary_data_setter, **kwargs):
         if boundary_obj.additional_data_init_callback:
             boundary_obj.additional_data_init_callback(boundary_data_setter, **kwargs)
-        if self._target in self._data_handling._GPU_LIKE_TARGETS:
+        if self._target.is_gpu():
            self._data_handling.to_gpu(self._index_array_name)
 
 class BoundaryInfo(object):
diff --git a/src/pystencils/codegen/target.py b/src/pystencils/codegen/target.py
index c4b08b95c..5e214430c 100644
--- a/src/pystencils/codegen/target.py
+++ b/src/pystencils/codegen/target.py
@@ -93,8 +93,8 @@ class Target(Flag):
     Generate a HIP kernel for generic AMD or NVidia GPUs.
     """
 
-    GPU = CUDA
-    """Alias for `Target.CUDA`, for backward compatibility."""
+    GPU = CurrentGPU
+    """Alias for `Target.CurrentGPU`, for backward compatibility."""
 
     SYCL = _SYCL
     """SYCL kernel target.
@@ -106,15 +106,24 @@ class Target(Flag):
     """
 
     def is_automatic(self) -> bool:
+        """Determine if this target is a proxy target that is automatically resolved
+        according to the runtime environment."""
         return Target._AUTOMATIC in self
 
     def is_cpu(self) -> bool:
+        """Determine if this target is a CPU target."""
         return Target._CPU in self
 
     def is_vector_cpu(self) -> bool:
+        """Determine if this target is a vector CPU target."""
         return self.is_cpu() and Target._VECTOR in self
 
     def is_gpu(self) -> bool:
+        """Determine if this target is a GPU target.
+
+        This refers to targets for the CUDA and HIP family of platforms.
+        `Target.SYCL` is *not* a GPU target.
+        """
         return Target._GPU in self
 
     @staticmethod
@@ -128,6 +137,11 @@ class Target(Flag):
 
     @staticmethod
     def auto_gpu() -> Target:
+        """Return the GPU target available in the current runtime environment.
+
+        Raises:
+            RuntimeError: If `cupy` is not installed and therefore no GPU runtime is available.
+        """
         try:
             import cupy
 
@@ -140,10 +154,11 @@ class Target(Flag):
 
     @staticmethod
     def available_targets() -> list[Target]:
+        """List all targets available in the current runtime environment."""
         targets = [Target.GenericCPU]
         try:
             import cupy  # noqa: F401
-            targets.append(Target.CUDA)
+            targets.append(Target.auto_gpu())
         except ImportError:
             pass
 
diff --git a/src/pystencils/datahandling/datahandling_interface.py b/src/pystencils/datahandling/datahandling_interface.py
index 867bbf062..a6b1fcb55 100644
--- a/src/pystencils/datahandling/datahandling_interface.py
+++ b/src/pystencils/datahandling/datahandling_interface.py
@@ -17,8 +17,6 @@ class DataHandling(ABC):
     'gather' function that has collects (parts of the) distributed data on a single process.
     """
 
-    _GPU_LIKE_TARGETS = [Target.GPU]
-
     # ---------------------------- Adding and accessing data -----------------------------------------------------------
     @property
     @abstractmethod
diff --git a/src/pystencils/datahandling/serial_datahandling.py b/src/pystencils/datahandling/serial_datahandling.py
index 73b749ca4..dc6904c3a 100644
--- a/src/pystencils/datahandling/serial_datahandling.py
+++ b/src/pystencils/datahandling/serial_datahandling.py
@@ -110,7 +110,7 @@ class SerialDataHandling(DataHandling):
         if layout is None:
             layout = self.default_layout
         if gpu is None:
-            gpu = self.default_target in self._GPU_LIKE_TARGETS
+            gpu = self.default_target.is_gpu()
 
         kwargs = {
             'shape': tuple(s + 2 * ghost_layers for s in self._domainSize),
@@ -241,7 +241,7 @@ class SerialDataHandling(DataHandling):
 
     def swap(self, name1, name2, gpu=None):
         if gpu is None:
-            gpu = self.default_target in self._GPU_LIKE_TARGETS
+            gpu = self.default_target.is_gpu()
         arr = self.gpu_arrays if gpu else self.cpu_arrays
         arr[name1], arr[name2] = arr[name2], arr[name1]
 
@@ -292,7 +292,7 @@ class SerialDataHandling(DataHandling):
         if target is None:
             target = self.default_target
 
-        if not (target.is_cpu() or target == Target.CUDA):
+        if not (target.is_cpu() or target.is_gpu()):
             raise ValueError(f"Unsupported target: {target}")
 
         if not hasattr(names, '__len__') or type(names) is str:
diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py
index 1461bfd7a..69e965325 100644
--- a/src/pystencils/jit/gpu_cupy.py
+++ b/src/pystencils/jit/gpu_cupy.py
@@ -231,9 +231,19 @@ class CupyJit(JitBase):
             )
 
         if not isinstance(kernel, GpuKernel):
-            raise ValueError(
+            raise JitError(
                 "The CupyJit just-in-time compiler only accepts GPU kernels generated for CUDA or HIP"
             )
+
+        if kernel.target == Target.CUDA and cp.cuda.runtime.is_hip:
+            raise JitError(
+                "Cannot compile a CUDA kernel on a HIP-based Cupy installation."
+            )
+
+        if kernel.target == Target.HIP and not cp.cuda.runtime.is_hip:
+            raise JitError(
+                "Cannot compile a HIP kernel on a CUDA-based Cupy installation."
+            )
 
         options = self._compiler_options()
         prelude = self._prelude(kernel)
diff --git a/tests/fixtures.py b/tests/fixtures.py
index a19519988..a4c77f550 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -18,19 +18,7 @@ from types import ModuleType
 
 import pystencils as ps
 
-AVAILABLE_TARGETS = [ps.Target.GenericCPU]
-
-try:
-    import cupy
-
-    if cupy.cuda.runtime.is_hip:
-        AVAILABLE_TARGETS += [ps.Target.HIP]
-    else:
-        AVAILABLE_TARGETS += [ps.Target.CUDA]
-except ImportError:
-    pass
-
-AVAILABLE_TARGETS += ps.Target.available_vector_cpu_targets()
+AVAILABLE_TARGETS = ps.Target.available_targets()
 
 TARGET_IDS = [t.name for t in AVAILABLE_TARGETS]
 
@@ -75,9 +63,9 @@ def xp(target: ps.Target) -> ModuleType:
     """Primary array module for the current target.
 
     Returns:
-        `cupy` if `target == Target.CUDA`, and `numpy` otherwise
+        `cupy` if `target.is_gpu()`, and `numpy` otherwise
     """
-    if target == ps.Target.CUDA:
+    if target.is_gpu():
         import cupy as xp
 
         return xp
diff --git a/tests/kernelcreation/test_functions.py b/tests/kernelcreation/test_functions.py
index a4d154d4b..182a59005 100644
--- a/tests/kernelcreation/test_functions.py
+++ b/tests/kernelcreation/test_functions.py
@@ -106,14 +106,14 @@ def function_domain(function_name, dtype):
         case "pow":
             return np.concatenate(
                 [
-                    [0., 1., 1.],
-                    rng.uniform(-1., 1., 8),
-                    rng.uniform(0., 5., 8),
+                    [0.0, 1.0, 1.0],
+                    rng.uniform(-1.0, 1.0, 8),
+                    rng.uniform(0.0, 5.0, 8),
                 ]
             ).astype(dtype), np.concatenate(
                 [
-                    [1., 0., 2.],
-                    np.arange(2., 10., 1.),
+                    [1.0, 0.0, 2.0],
+                    np.arange(2.0, 10.0, 1.0),
                     rng.uniform(-2.0, 2.0, 8),
                 ]
             ).astype(
@@ -211,14 +211,14 @@ def test_binary_functions(gen_config, xp, function_name, dtype, function_domain)
 
 dtype_and_target_for_integer_funcs = pytest.mark.parametrize(
     "dtype, target",
-    list(product([np.int32], [t for t in AVAIL_TARGETS if t is not Target.CUDA]))
+    list(product([np.int32], [t for t in AVAIL_TARGETS if not t.is_gpu()]))
     + list(
         product(
             [np.int64],
             [
                 t
                 for t in AVAIL_TARGETS
-                if t not in (Target.X86_SSE, Target.X86_AVX, Target.CUDA)
+                if t not in (Target.X86_SSE, Target.X86_AVX) and not t.is_gpu()
             ],
         )
     ),
diff --git a/tests/kernelcreation/test_gpu.py b/tests/kernelcreation/test_gpu.py
index f1905b1fc..8d943e8fd 100644
--- a/tests/kernelcreation/test_gpu.py
+++ b/tests/kernelcreation/test_gpu.py
@@ -45,7 +45,7 @@ def test_indexing_options_3d(
         + src[0, 0, 1],
     )
 
-    cfg = CreateKernelConfig(target=Target.CUDA)
+    cfg = CreateKernelConfig(target=Target.CurrentGPU)
     cfg.gpu.indexing_scheme = indexing_scheme
     cfg.gpu.omit_range_check = omit_range_check
     cfg.gpu.manual_launch_grid = manual_grid
@@ -91,7 +91,7 @@ def test_indexing_options_2d(
         + src[0, 1]
     )
 
-    cfg = CreateKernelConfig(target=Target.CUDA)
+    cfg = CreateKernelConfig(target=Target.CurrentGPU)
     cfg.gpu.indexing_scheme = indexing_scheme
     cfg.gpu.omit_range_check = omit_range_check
     cfg.gpu.manual_launch_grid = manual_grid
@@ -126,7 +126,7 @@ def test_invalid_indexing_schemes():
     src, dst = fields("src, dst: [4D]")
     asm = Assignment(src.center(0), dst.center(0))
 
-    cfg = CreateKernelConfig(target=Target.CUDA)
+    cfg = CreateKernelConfig(target=Target.CurrentGPU)
     cfg.gpu.indexing_scheme = "linear3d"
 
     with pytest.raises(Exception):
@@ -241,7 +241,7 @@ def test_ghost_layer():
     ghost_layers = [(1, 2), (2, 1)]
 
     config = CreateKernelConfig()
-    config.target = Target.CUDA
+    config.target = Target.CurrentGPU
     config.ghost_layers = ghost_layers
     config.gpu.indexing_scheme = "blockwise4d"
 
@@ -270,7 +270,7 @@ def test_setting_value():
     update_rule = [Assignment(f(0), sp.Symbol("value"))]
 
     config = CreateKernelConfig()
-    config.target = Target.CUDA
+    config.target = Target.CurrentGPU
     config.iteration_slice = iteration_slice
     config.gpu.indexing_scheme = "blockwise4d"
 
diff --git a/tests/kernelcreation/test_half_precision.py b/tests/kernelcreation/test_half_precision.py
index a9745459d..5dbe2180e 100644
--- a/tests/kernelcreation/test_half_precision.py
+++ b/tests/kernelcreation/test_half_precision.py
@@ -5,7 +5,7 @@ import numpy as np
 import pystencils as ps
 
 
-@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
+@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.CurrentGPU))
 def test_half_precison(target):
     if target == ps.Target.CPU:
         if not platform.machine() in ['arm64', 'aarch64']:
@@ -14,7 +14,7 @@ def test_half_precison(target):
         if 'clang' not in ps.cpu.cpujit.get_compiler_config()['command']:
             pytest.xfail("skipping half precision because clang compiler is not used")
 
-    if target == ps.Target.GPU:
+    if target.is_gpu():
         pytest.importorskip("cupy")
 
     dh = ps.create_data_handling(domain_size=(10, 10), default_target=target)
diff --git a/tests/kernelcreation/test_index_kernels.py b/tests/kernelcreation/test_index_kernels.py
index 569c0ab6a..bda0ef273 100644
--- a/tests/kernelcreation/test_index_kernels.py
+++ b/tests/kernelcreation/test_index_kernels.py
@@ -5,14 +5,7 @@ from pystencils import Assignment, Field, FieldType, AssignmentCollection, Targe
 from pystencils import create_kernel, CreateKernelConfig
 
 
-@pytest.mark.parametrize("target", [Target.CPU, Target.GPU])
-def test_indexed_kernel(target):
-    if target == Target.GPU:
-        cp = pytest.importorskip("cupy")
-        xp = cp
-    else:
-        xp = np
-
+def test_indexed_kernel(target, xp):
     arr = xp.zeros((3, 4))
     dtype = np.dtype([('x', int), ('y', int), ('value', arr.dtype)], align=True)
 
@@ -21,8 +14,8 @@ def test_indexed_kernel(target):
     cpu_index_arr[1] = (1, 3, 42.0)
     cpu_index_arr[2] = (2, 1, 5.0)
 
-    if target == Target.GPU:
-        gpu_index_arr = cp.empty(cpu_index_arr.shape, cpu_index_arr.dtype)
+    if target.is_gpu():
+        gpu_index_arr = xp.empty(cpu_index_arr.shape, cpu_index_arr.dtype)
         gpu_index_arr.set(cpu_index_arr)
         index_arr = gpu_index_arr
     else:
@@ -40,8 +33,8 @@ def test_indexed_kernel(target):
 
     kernel(f=arr, index=index_arr)
 
-    if target == Target.GPU:
-        arr = cp.asnumpy(arr)
+    if target.is_gpu():
+        arr = xp.asnumpy(arr)
 
     for i in range(cpu_index_arr.shape[0]):
         np.testing.assert_allclose(arr[cpu_index_arr[i]['x'], cpu_index_arr[i]['y']], cpu_index_arr[i]['value'], atol=1e-13)
diff --git a/tests/kernelcreation/test_iteration_slices.py b/tests/kernelcreation/test_iteration_slices.py
index b1f2da576..2b3a8ebf0 100644
--- a/tests/kernelcreation/test_iteration_slices.py
+++ b/tests/kernelcreation/test_iteration_slices.py
@@ -144,7 +144,7 @@ def test_triangle_pattern(gen_config: CreateKernelConfig, xp):
     islice = make_slice[:, slow_counter:]
     gen_config = replace(gen_config, iteration_slice=islice)
 
-    if gen_config.target == Target.CUDA:
+    if gen_config.target.is_gpu():
         gen_config.gpu.manual_launch_grid = True
 
     kernel = create_kernel(update, gen_config).compile()
@@ -177,7 +177,7 @@ def test_red_black_pattern(gen_config: CreateKernelConfig, xp):
     islice = make_slice[:, start::2]
     gen_config.iteration_slice = islice
 
-    if gen_config.target == Target.CUDA:
+    if gen_config.target.is_gpu():
         gen_config.gpu.manual_launch_grid = True
 
     try:
diff --git a/tests/runtime/test_boundary.py b/tests/runtime/test_boundary.py
index fb8f827e8..226510b83 100644
--- a/tests/runtime/test_boundary.py
+++ b/tests/runtime/test_boundary.py
@@ -98,7 +98,7 @@ def test_kernel_vs_copy_boundary():
 
 def test_boundary_gpu():
     pytest.importorskip('cupy')
-    dh = SerialDataHandling(domain_size=(7, 7), default_target=Target.GPU)
+    dh = SerialDataHandling(domain_size=(7, 7), default_target=Target.CurrentGPU)
     src = dh.add_array('src')
     dh.fill("src", 0.0, ghost_layers=True)
     dh.fill("src", 1.0, ghost_layers=False)
@@ -111,7 +111,7 @@ def test_boundary_gpu():
                                            name="boundary_handling_cpu", target=Target.CPU)
 
     boundary_handling = BoundaryHandling(dh, src.name, boundary_stencil,
-                                         name="boundary_handling_gpu", target=Target.GPU)
+                                         name="boundary_handling_gpu", target=Target.CurrentGPU)
 
     neumann = Neumann()
     for d in ('N', 'S', 'W', 'E'):
diff --git a/tests/runtime/test_datahandling.py b/tests/runtime/test_datahandling.py
index 9d7ff924e..9e7c73cac 100644
--- a/tests/runtime/test_datahandling.py
+++ b/tests/runtime/test_datahandling.py
@@ -118,7 +118,7 @@ def synchronization(dh, test_gpu=False):
 
 
 def kernel_execution_jacobi(dh, target):
-    test_gpu = target == Target.GPU
+    test_gpu = target == Target.CurrentGPU
     dh.add_array('f', gpu=test_gpu)
     dh.add_array('tmp', gpu=test_gpu)
 
@@ -219,15 +219,15 @@ def test_kernel():
     try:
         import cupy
         dh = create_data_handling(domain_size=domain_shape, periodicity=True)
-        kernel_execution_jacobi(dh, Target.GPU)
+        kernel_execution_jacobi(dh, Target.CurrentGPU)
     except ImportError:
         pass
 
 
-@pytest.mark.parametrize('target', (Target.CPU, Target.GPU))
+@pytest.mark.parametrize('target', (Target.CPU, Target.CurrentGPU))
 def test_kernel_param(target):
     for domain_shape in [(4, 5), (3, 4, 5)]:
-        if target == Target.GPU:
+        if target == Target.CurrentGPU:
             pytest.importorskip('cupy')
 
         dh = create_data_handling(domain_size=domain_shape, periodicity=True, default_target=target)
@@ -262,7 +262,7 @@ def test_add_arrays():
 def test_add_arrays_with_layout(shape, layout):
     pytest.importorskip('cupy')
 
-    dh = create_data_handling(domain_size=shape, default_layout=layout, default_target=ps.Target.GPU)
+    dh = create_data_handling(domain_size=shape, default_layout=layout, default_target=ps.Target.CurrentGPU)
     f1 = dh.add_array("f1", values_per_cell=19)
     dh.fill(f1.name, 1.0)
 
@@ -392,8 +392,6 @@ def test_array_handler(device_number):
     empty = array_handler.empty(shape=size, order="F")
     assert empty.strides == (8, 16)
 
-    random_array = array_handler.randn(size)
-
     cpu_array = np.empty((20, 40), dtype=np.float64)
    gpu_array = array_handler.to_gpu(cpu_array)
 
-- 
GitLab
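
Usage sketch (illustrative, not part of the patch): the hunks above move the tests from hard-coded
Target.CUDA/Target.GPU comparisons to the proxy-based API. The snippet below only uses names that
appear in this diff (Target.available_targets, Target.auto_gpu, Target.is_gpu, Target.CurrentGPU,
CreateKernelConfig) and assumes a pystencils 2.x installation; treat it as a sketch, not as code
introduced by this change.

    # Illustrative sketch only -- relies on the Target API shown in the diff above.
    import pystencils as ps

    # Enumerate the targets usable in the current runtime environment
    # (GenericCPU always; a GPU target only if cupy is importable).
    for target in ps.Target.available_targets():
        kind = "GPU" if target.is_gpu() else "CPU"
        print(f"{target.name}: {kind} target")

    # Resolve the concrete GPU target explicitly; per the docstring added above,
    # auto_gpu() raises RuntimeError when cupy is not installed.
    try:
        gpu = ps.Target.auto_gpu()
    except RuntimeError:
        gpu = None

    if gpu is not None:
        # Generic GPU code generation now goes through the CurrentGPU proxy
        # instead of hard-coding Target.CUDA.
        cfg = ps.CreateKernelConfig(target=ps.Target.CurrentGPU)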