Skip to content
Snippets Groups Projects
Commit 1565ea2b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fixed test for sliced iteration with buffer to use dynamic field sizes

parent e442d084
No related merge requests found
......@@ -19,7 +19,7 @@ def _generate_fields(dt=np.uint64, num_directions=1, layout='numpy'):
fields = []
for size in field_sizes:
field_layout = layout_string_to_tuple(layout, len(size))
src_arr = create_numpy_array_with_layout(size, field_layout)
src_arr = create_numpy_array_with_layout(size, field_layout, dtype=dt)
array_data = np.reshape(np.arange(1, int(np.prod(size)+1)), size)
# Use flat iterator to input data into the array
......@@ -190,10 +190,14 @@ def test_field_layouts():
def test_iteration_slices():
num_cell_values = 19
fields = _generate_fields(num_directions=num_cell_values)
dt = np.uint64
fields = _generate_fields(dt=dt, num_directions=num_cell_values)
for (src_arr, dst_arr, bufferArr) in fields:
src_field = Field.create_from_numpy_array("src_field", src_arr, index_dimensions=1)
dst_field = Field.create_from_numpy_array("dst_field", dst_arr, index_dimensions=1)
spatial_dimensions = len(src_arr.shape) - 1
# src_field = Field.create_from_numpy_array("src_field", src_arr, index_dimensions=1)
# dst_field = Field.create_from_numpy_array("dst_field", dst_arr, index_dimensions=1)
src_field = Field.create_generic("src_field", spatial_dimensions, index_shape=(num_cell_values,), dtype=dt)
dst_field = Field.create_generic("dst_field", spatial_dimensions, index_shape=(num_cell_values,), dtype=dt)
buffer = Field.create_generic("buffer", spatial_dimensions=1, index_dimensions=1,
field_type=FieldType.BUFFER, dtype=src_arr.dtype)
......@@ -211,7 +215,7 @@ def test_iteration_slices():
# Fill the entire array with data
src_arr[(slice(None, None, 1),) * dim] = np.arange(num_cell_values)
dst_arr.fill(0.0)
dst_arr.fill(0)
pack_code = create_kernel(pack_eqs, iteration_slice=pack_slice, data_type={'src_field': src_arr.dtype, 'buffer': buffer.dtype})
pack_kernel = pack_code.compile()
......@@ -229,6 +233,6 @@ def test_iteration_slices():
# Check if only every second entry of the leftmost slice has been copied
np.testing.assert_equal(dst_arr[pack_slice], src_arr[pack_slice])
np.testing.assert_equal(dst_arr[(slice(1, None, 2),) * (dim-1) + (0,)], 0.0)
np.testing.assert_equal(dst_arr[(slice(None, None, 1),) * (dim-1) + (slice(1,None),)], 0.0)
np.testing.assert_equal(dst_arr[(slice(1, None, 2),) * (dim-1) + (0,)], 0)
np.testing.assert_equal(dst_arr[(slice(None, None, 1),) * (dim-1) + (slice(1,None),)], 0)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment