Skip to content
Snippets Groups Projects
Commit e4c6daa7 authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'fix_buffer_slicing_test' into 'master'

Fixed test for sliced iteration with buffer to use dynamic field sizes

See merge request pycodegen/pystencils!280
parents e442d084 1565ea2b
No related branches found
No related tags found
1 merge request!280Fixed test for sliced iteration with buffer to use dynamic field sizes
Pipeline #35985 passed
......@@ -19,7 +19,7 @@ def _generate_fields(dt=np.uint64, num_directions=1, layout='numpy'):
fields = []
for size in field_sizes:
field_layout = layout_string_to_tuple(layout, len(size))
src_arr = create_numpy_array_with_layout(size, field_layout)
src_arr = create_numpy_array_with_layout(size, field_layout, dtype=dt)
array_data = np.reshape(np.arange(1, int(np.prod(size)+1)), size)
# Use flat iterator to input data into the array
......@@ -190,10 +190,14 @@ def test_field_layouts():
def test_iteration_slices():
num_cell_values = 19
fields = _generate_fields(num_directions=num_cell_values)
dt = np.uint64
fields = _generate_fields(dt=dt, num_directions=num_cell_values)
for (src_arr, dst_arr, bufferArr) in fields:
src_field = Field.create_from_numpy_array("src_field", src_arr, index_dimensions=1)
dst_field = Field.create_from_numpy_array("dst_field", dst_arr, index_dimensions=1)
spatial_dimensions = len(src_arr.shape) - 1
# src_field = Field.create_from_numpy_array("src_field", src_arr, index_dimensions=1)
# dst_field = Field.create_from_numpy_array("dst_field", dst_arr, index_dimensions=1)
src_field = Field.create_generic("src_field", spatial_dimensions, index_shape=(num_cell_values,), dtype=dt)
dst_field = Field.create_generic("dst_field", spatial_dimensions, index_shape=(num_cell_values,), dtype=dt)
buffer = Field.create_generic("buffer", spatial_dimensions=1, index_dimensions=1,
field_type=FieldType.BUFFER, dtype=src_arr.dtype)
......@@ -211,7 +215,7 @@ def test_iteration_slices():
# Fill the entire array with data
src_arr[(slice(None, None, 1),) * dim] = np.arange(num_cell_values)
dst_arr.fill(0.0)
dst_arr.fill(0)
pack_code = create_kernel(pack_eqs, iteration_slice=pack_slice, data_type={'src_field': src_arr.dtype, 'buffer': buffer.dtype})
pack_kernel = pack_code.compile()
......@@ -229,6 +233,6 @@ def test_iteration_slices():
# Check if only every second entry of the leftmost slice has been copied
np.testing.assert_equal(dst_arr[pack_slice], src_arr[pack_slice])
np.testing.assert_equal(dst_arr[(slice(1, None, 2),) * (dim-1) + (0,)], 0.0)
np.testing.assert_equal(dst_arr[(slice(None, None, 1),) * (dim-1) + (slice(1,None),)], 0.0)
np.testing.assert_equal(dst_arr[(slice(1, None, 2),) * (dim-1) + (0,)], 0)
np.testing.assert_equal(dst_arr[(slice(None, None, 1),) * (dim-1) + (slice(1,None),)], 0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment