Skip to content
Snippets Groups Projects
Commit 1d4c3bf6 authored by Frederik Hennig's avatar Frederik Hennig Committed by Markus Holzer
Browse files

Fixed test for sliced iteration with buffer to use dynamic field sizes

parent 5152e4e8
No related branches found
No related tags found
1 merge request!292Rebase of pystencils Type System
......@@ -20,7 +20,7 @@ def _generate_fields(dt=np.uint64, num_directions=1, layout='numpy'):
fields = []
for size in field_sizes:
field_layout = layout_string_to_tuple(layout, len(size))
src_arr = create_numpy_array_with_layout(size, field_layout)
src_arr = create_numpy_array_with_layout(size, field_layout, dtype=dt)
array_data = np.reshape(np.arange(1, int(np.prod(size)+1)), size)
# Use flat iterator to input data into the array
......@@ -193,10 +193,14 @@ def test_field_layouts():
def test_iteration_slices():
num_cell_values = 19
fields = _generate_fields(num_directions=num_cell_values)
dt = np.uint64
fields = _generate_fields(dt=dt, num_directions=num_cell_values)
for (src_arr, dst_arr, bufferArr) in fields:
src_field = Field.create_from_numpy_array("src_field", src_arr, index_dimensions=1)
dst_field = Field.create_from_numpy_array("dst_field", dst_arr, index_dimensions=1)
spatial_dimensions = len(src_arr.shape) - 1
# src_field = Field.create_from_numpy_array("src_field", src_arr, index_dimensions=1)
# dst_field = Field.create_from_numpy_array("dst_field", dst_arr, index_dimensions=1)
src_field = Field.create_generic("src_field", spatial_dimensions, index_shape=(num_cell_values,), dtype=dt)
dst_field = Field.create_generic("dst_field", spatial_dimensions, index_shape=(num_cell_values,), dtype=dt)
buffer = Field.create_generic("buffer", spatial_dimensions=1, index_dimensions=1,
field_type=FieldType.BUFFER, dtype=src_arr.dtype)
......@@ -214,7 +218,7 @@ def test_iteration_slices():
# Fill the entire array with data
src_arr[(slice(None, None, 1),) * dim] = np.arange(num_cell_values)
dst_arr.fill(0.0)
dst_arr.fill(0)
pack_code = create_kernel(pack_eqs, iteration_slice=pack_slice, data_type={'src_field': src_arr.dtype, 'buffer': buffer.dtype})
pack_kernel = pack_code.compile()
......@@ -232,6 +236,6 @@ def test_iteration_slices():
# Check if only every second entry of the leftmost slice has been copied
np.testing.assert_equal(dst_arr[pack_slice], src_arr[pack_slice])
np.testing.assert_equal(dst_arr[(slice(1, None, 2),) * (dim-1) + (0,)], 0.0)
np.testing.assert_equal(dst_arr[(slice(None, None, 1),) * (dim-1) + (slice(1,None),)], 0.0)
np.testing.assert_equal(dst_arr[(slice(1, None, 2),) * (dim-1) + (0,)], 0)
np.testing.assert_equal(dst_arr[(slice(None, None, 1),) * (dim-1) + (slice(1,None),)], 0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment