Skip to content
Snippets Groups Projects
Commit 057aeb58 authored by Frederik Hennig
Browse files

Merge branch 'FixGPUIndexing' into 'master'

[BUGFIX] GPU slicing

Closes #90

See merge request !396
parents 76e63d1a 54b01e22
Branches
Tags
1 merge request: !396 [BUGFIX] GPU slicing
Pipeline #67327 passed with stages
in 17 minutes and 40 seconds
...@@ -224,6 +224,9 @@ class BlockIndexing(AbstractIndexing): ...@@ -224,6 +224,9 @@ class BlockIndexing(AbstractIndexing):
assert len(self._iteration_space) == len(arr_shape), "Iteration space must be equal to the array shape" assert len(self._iteration_space) == len(arr_shape), "Iteration space must be equal to the array shape"
numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape) numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
end = [s.stop if s.stop != 0 else 1 for s in numeric_iteration_slice] end = [s.stop if s.stop != 0 else 1 for s in numeric_iteration_slice]
for i, s in enumerate(numeric_iteration_slice):
if s.step and s.step != 1:
end[i] = div_ceil(s.stop - s.start, s.step) + s.start
if self._dim < 4: if self._dim < 4:
conditions = [c < e for c, e in zip(self.coordinates, end)] conditions = [c < e for c, e in zip(self.coordinates, end)]
......
import pytest import pytest
import numpy as np import numpy as np
import cupy as cp
import sympy as sp import sympy as sp
import math
from scipy.ndimage import convolve from scipy.ndimage import convolve
from pystencils import Assignment, Field, fields, CreateKernelConfig, create_kernel, Target from pystencils import Assignment, Field, fields, CreateKernelConfig, create_kernel, Target, get_code_str
from pystencils.gpu import BlockIndexing from pystencils.gpu import BlockIndexing
from pystencils.simp import sympy_cse_on_assignment_list from pystencils.simp import sympy_cse_on_assignment_list
from pystencils.slicing import add_ghost_layers, make_slice, remove_ghost_layers, normalize_slice from pystencils.slicing import add_ghost_layers, make_slice, remove_ghost_layers, normalize_slice
try: try:
import cupy import cupy as cp
device_numbers = range(cupy.cuda.runtime.getDeviceCount()) device_numbers = range(cp.cuda.runtime.getDeviceCount())
except ImportError: except ImportError:
device_numbers = [] device_numbers = []
cp = None
def test_averaging_kernel(): def test_averaging_kernel():
pytest.importorskip('cupy')
size = (40, 55) size = (40, 55)
src_arr = np.random.rand(*size) src_arr = np.random.rand(*size)
src_arr = add_ghost_layers(src_arr) src_arr = add_ghost_layers(src_arr)
...@@ -44,6 +46,7 @@ def test_averaging_kernel(): ...@@ -44,6 +46,7 @@ def test_averaging_kernel():
def test_variable_sized_fields(): def test_variable_sized_fields():
pytest.importorskip('cupy')
src_field = Field.create_generic('src', spatial_dimensions=2) src_field = Field.create_generic('src', spatial_dimensions=2)
dst_field = Field.create_generic('dst', spatial_dimensions=2) dst_field = Field.create_generic('dst', spatial_dimensions=2)
...@@ -71,6 +74,7 @@ def test_variable_sized_fields(): ...@@ -71,6 +74,7 @@ def test_variable_sized_fields():
def test_multiple_index_dimensions(): def test_multiple_index_dimensions():
pytest.importorskip('cupy')
"""Sums along the last axis of a numpy array""" """Sums along the last axis of a numpy array"""
src_size = (7, 6, 4) src_size = (7, 6, 4)
dst_size = src_size[:2] dst_size = src_size[:2]
...@@ -103,6 +107,7 @@ def test_multiple_index_dimensions(): ...@@ -103,6 +107,7 @@ def test_multiple_index_dimensions():
def test_ghost_layer(): def test_ghost_layer():
pytest.importorskip('cupy')
size = (6, 5) size = (6, 5)
src_arr = np.ones(size) src_arr = np.ones(size)
dst_arr = np.zeros_like(src_arr) dst_arr = np.zeros_like(src_arr)
...@@ -127,6 +132,7 @@ def test_ghost_layer(): ...@@ -127,6 +132,7 @@ def test_ghost_layer():
def test_setting_value(): def test_setting_value():
pytest.importorskip('cupy')
arr_cpu = np.arange(25, dtype=np.float64).reshape(5, 5) arr_cpu = np.arange(25, dtype=np.float64).reshape(5, 5)
arr_gpu = cp.asarray(arr_cpu) arr_gpu = cp.asarray(arr_cpu)
...@@ -143,6 +149,7 @@ def test_setting_value(): ...@@ -143,6 +149,7 @@ def test_setting_value():
def test_periodicity(): def test_periodicity():
pytest.importorskip('cupy')
from pystencils.gpu.periodicity import get_periodic_boundary_functor as periodic_gpu from pystencils.gpu.periodicity import get_periodic_boundary_functor as periodic_gpu
from pystencils.slicing import get_periodic_boundary_functor as periodic_cpu from pystencils.slicing import get_periodic_boundary_functor as periodic_cpu
...@@ -163,6 +170,7 @@ def test_periodicity(): ...@@ -163,6 +170,7 @@ def test_periodicity():
@pytest.mark.parametrize("device_number", device_numbers) @pytest.mark.parametrize("device_number", device_numbers)
def test_block_indexing(device_number): def test_block_indexing(device_number):
pytest.importorskip('cupy')
f = fields("f: [3D]") f = fields("f: [3D]")
s = normalize_slice(make_slice[:, :, :], f.spatial_shape) s = normalize_slice(make_slice[:, :, :], f.spatial_shape)
bi = BlockIndexing(s, f.layout, block_size=(16, 8, 2), bi = BlockIndexing(s, f.layout, block_size=(16, 8, 2),
...@@ -195,6 +203,7 @@ def test_block_indexing(device_number): ...@@ -195,6 +203,7 @@ def test_block_indexing(device_number):
@pytest.mark.parametrize('layout', ("C", "F")) @pytest.mark.parametrize('layout', ("C", "F"))
@pytest.mark.parametrize('shape', ((5, 5, 5, 5), (3, 17, 387, 4), (23, 44, 21, 11))) @pytest.mark.parametrize('shape', ((5, 5, 5, 5), (3, 17, 387, 4), (23, 44, 21, 11)))
def test_four_dimensional_kernel(gpu_indexing, layout, shape): def test_four_dimensional_kernel(gpu_indexing, layout, shape):
pytest.importorskip('cupy')
n_elements = np.prod(shape) n_elements = np.prod(shape)
arr_cpu = np.arange(n_elements, dtype=np.float64).reshape(shape, order=layout) arr_cpu = np.arange(n_elements, dtype=np.float64).reshape(shape, order=layout)
...@@ -210,3 +219,39 @@ def test_four_dimensional_kernel(gpu_indexing, layout, shape): ...@@ -210,3 +219,39 @@ def test_four_dimensional_kernel(gpu_indexing, layout, shape):
kernel(f=arr_gpu, value=np.float64(42.0)) kernel(f=arr_gpu, value=np.float64(42.0))
np.testing.assert_equal(arr_gpu.get(), np.ones(shape) * 42.0) np.testing.assert_equal(arr_gpu.get(), np.ones(shape) * 42.0)
@pytest.mark.parametrize('start', (1, 5))
@pytest.mark.parametrize('end', (-1, -2, -3, -4))
@pytest.mark.parametrize('step', (1, 2, 3, 4))
@pytest.mark.parametrize('shape', ([55, 60], [77, 101, 80], [44, 64, 66]))
def test_guards_with_iteration_slices(start, end, step, shape):
    """Check the guard bounds emitted for a strided GPU iteration slice.

    A kernel is created for the GPU target with the same
    ``slice(start, end, step)`` applied to every dimension of ``shape``.
    For each dimension the expected upper bound is
    ``ceil((shape[i] + end - start) / step) + start``; the test asserts that
    the text ``"<start> < <bound>"`` occurs in the generated code and that
    the number of integers in ``[start, bound)`` equals the number of
    elements ``range(start, shape[i] + end, step)`` would visit.
    """
    # The very same slice object is reused for every spatial dimension.
    iter_slice = (slice(start, end, step),) * len(shape)

    config = CreateKernelConfig(target=Target.GPU, iteration_slice=iter_slice)
    field_1 = fields(f"f(1) : double{list(shape)}")
    ast = create_kernel(Assignment(field_1.center, 1), config=config)
    code_str = get_code_str(ast)

    expected_guards = []
    guarded_lengths = []
    for extent, dim_slice in zip(shape, iter_slice):
        # `end` is negative, so `extent + end` is the absolute stop index.
        span = ((extent + end) - dim_slice.start) / dim_slice.step
        upper = math.ceil(span) + dim_slice.start
        expected_guards.append(f"{dim_slice.start} < {upper}")
        # Count of unit steps from the slice start up to (excluding) `upper`.
        guarded_lengths.append(len(range(dim_slice.start, upper)))

    # check if the expected if statement is in the GPU code
    for guard in expected_guards:
        assert guard in code_str

    # check if these bounds lead to same lengths as the range function would produce
    for extent, dim_slice, length in zip(shape, iter_slice, guarded_lengths):
        assert length == len(range(dim_slice.start, extent + end, dim_slice.step))
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment