Skip to content
Snippets Groups Projects
Commit 7f76698d authored by Markus Holzer's avatar Markus Holzer
Browse files

Added save and load test to datahandling tests

parent d0a06963
No related branches found
No related tags found
No related merge requests found
......@@ -425,14 +425,15 @@ class SerialDataHandling(DataHandling):
np.savez_compressed(file, **self.cpu_arrays)
def load_all(self, file):
if '.npz' not in file:
file += '.npz'
file_contents = np.load(file)
for arr_name, arr_contents in self.cpu_arrays.items():
if arr_name not in file_contents:
print(f"Skipping read data {arr_name} because there is no data with this name in data handling")
continue
if file_contents[arr_name].shape != arr_contents.shape:
print("Skipping read data {} because shapes don't match. "
"Read array shape {}, existing array shape {}".format(arr_name, file_contents[arr_name].shape,
arr_contents.shape))
print(f"Skipping read data {arr_name} because shapes don't match. "
f"Read array shape {file_contents[arr_name].shape}, existing array shape {arr_contents.shape}")
continue
np.copyto(arr_contents, file_contents[arr_name])
File added
File added
......@@ -310,3 +310,44 @@ def test_log():
dh.log_on_root()
assert dh.is_root
assert dh.world_rank == 0
def test_save_data():
domain_shape = (2, 2)
dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1)
dh.add_array("src", values_per_cell=9)
dh.fill("src", 1.0, ghost_layers=True)
dh.add_array("dst", values_per_cell=9)
dh.fill("dst", 1.0, ghost_layers=True)
dh.save_all('test_data/datahandling_save_test')
def test_load_data():
domain_shape = (2, 2)
dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1)
dh.add_array("src", values_per_cell=9)
dh.fill("src", 0.0, ghost_layers=True)
dh.add_array("dst", values_per_cell=9)
dh.fill("dst", 0.0, ghost_layers=True)
dh.load_all('test_data/datahandling_load_test')
assert np.all(dh.cpu_arrays['src']) == 1
assert np.all(dh.cpu_arrays['dst']) == 1
domain_shape = (3, 3)
dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1)
dh.add_array("src", values_per_cell=9)
dh.fill("src", 0.0, ghost_layers=True)
dh.add_array("dst", values_per_cell=9)
dh.fill("dst", 0.0, ghost_layers=True)
dh.add_array("dst2", values_per_cell=9)
dh.fill("dst2", 0.0, ghost_layers=True)
dh.load_all('test_data/datahandling_load_test')
assert np.all(dh.cpu_arrays['src']) == 0
assert np.all(dh.cpu_arrays['dst']) == 0
assert np.all(dh.cpu_arrays['dst2']) == 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment