From 7f76698d227f1b98b8c1bf0f8859a1b9243d1fa3 Mon Sep 17 00:00:00 2001
From: markus holzer <markus.holzer@fau.de>
Date: Sun, 9 Aug 2020 09:42:42 +0200
Subject: [PATCH] Added save and load test to datahandling tests

---
 .../datahandling/serial_datahandling.py       |   7 +--
 .../test_data/datahandling_load_test.npz      | Bin 0 -> 410 bytes
 .../test_data/datahandling_save_test.npz      | Bin 0 -> 410 bytes
 pystencils_tests/test_datahandling.py         |  41 ++++++++++++++++++
 4 files changed, 45 insertions(+), 3 deletions(-)
 create mode 100644 pystencils_tests/test_data/datahandling_load_test.npz
 create mode 100644 pystencils_tests/test_data/datahandling_save_test.npz

diff --git a/pystencils/datahandling/serial_datahandling.py b/pystencils/datahandling/serial_datahandling.py
index 25a4d23e..ce4629f6 100644
--- a/pystencils/datahandling/serial_datahandling.py
+++ b/pystencils/datahandling/serial_datahandling.py
@@ -425,14 +425,15 @@ class SerialDataHandling(DataHandling):
         np.savez_compressed(file, **self.cpu_arrays)
 
     def load_all(self, file):
+        if '.npz' not in file:
+            file += '.npz'
         file_contents = np.load(file)
         for arr_name, arr_contents in self.cpu_arrays.items():
             if arr_name not in file_contents:
                 print(f"Skipping read data {arr_name} because there is no data with this name in data handling")
                 continue
             if file_contents[arr_name].shape != arr_contents.shape:
-                print("Skipping read data {} because shapes don't match. "
-                      "Read array shape {}, existing array shape {}".format(arr_name, file_contents[arr_name].shape,
-                                                                            arr_contents.shape))
+                print(f"Skipping read data {arr_name} because shapes don't match. "
+                      f"Read array shape {file_contents[arr_name].shape}, existing array shape {arr_contents.shape}")
                 continue
             np.copyto(arr_contents, file_contents[arr_name])
diff --git a/pystencils_tests/test_data/datahandling_load_test.npz b/pystencils_tests/test_data/datahandling_load_test.npz
new file mode 100644
index 0000000000000000000000000000000000000000..d363a8a0aba1bb78a06314a19b887eb4c4975334
GIT binary patch
literal 410
zcmWIWW@Zs#U|`^2U|>*W=q~1y3SnSiU|?lnU}q3vC@xCY%PXj4WDo!gfn>na3=9mj
z--y4G7C3n#;8?)gd6S~%#4O2Mx*%_I+QM~<7tEU$9}+ZWhWPyWDU-N__%*eBS-Hg1
zRm7Gs?UYip;&Mn?xpHy9Y6rCkNes>W-xhz-u=VP{^ko?%e#fK~m(bTSj7%cTxWWb$
p6cF6N2%@1O72r**rd|dH2B?n$yjj^mW-u`@F=R0?Fi3$-1OO#vVK)E(

literal 0
HcmV?d00001

diff --git a/pystencils_tests/test_data/datahandling_save_test.npz b/pystencils_tests/test_data/datahandling_save_test.npz
new file mode 100644
index 0000000000000000000000000000000000000000..d363a8a0aba1bb78a06314a19b887eb4c4975334
GIT binary patch
literal 410
zcmWIWW@Zs#U|`^2U|>*W=q~1y3SnSiU|?lnU}q3vC@xCY%PXj4WDo!gfn>na3=9mj
z--y4G7C3n#;8?)gd6S~%#4O2Mx*%_I+QM~<7tEU$9}+ZWhWPyWDU-N__%*eBS-Hg1
zRm7Gs?UYip;&Mn?xpHy9Y6rCkNes>W-xhz-u=VP{^ko?%e#fK~m(bTSj7%cTxWWb$
p6cF6N2%@1O72r**rd|dH2B?n$yjj^mW-u`@F=R0?Fi3$-1OO#vVK)E(

literal 0
HcmV?d00001

diff --git a/pystencils_tests/test_datahandling.py b/pystencils_tests/test_datahandling.py
index 6e53d1e8..c18cfba9 100644
--- a/pystencils_tests/test_datahandling.py
+++ b/pystencils_tests/test_datahandling.py
@@ -310,3 +310,44 @@ def test_log():
     dh.log_on_root()
     assert dh.is_root
     assert dh.world_rank == 0
+
+
+def test_save_data():
+    domain_shape = (2, 2)
+
+    dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1)
+    dh.add_array("src", values_per_cell=9)
+    dh.fill("src", 1.0, ghost_layers=True)
+    dh.add_array("dst", values_per_cell=9)
+    dh.fill("dst", 1.0, ghost_layers=True)
+
+    dh.save_all('test_data/datahandling_save_test')
+
+
+def test_load_data():
+    domain_shape = (2, 2)
+
+    dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1)
+    dh.add_array("src", values_per_cell=9)
+    dh.fill("src", 0.0, ghost_layers=True)
+    dh.add_array("dst", values_per_cell=9)
+    dh.fill("dst", 0.0, ghost_layers=True)
+
+    dh.load_all('test_data/datahandling_load_test')
+    assert np.all(dh.cpu_arrays['src']) == 1
+    assert np.all(dh.cpu_arrays['dst']) == 1
+
+    domain_shape = (3, 3)
+
+    dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=1)
+    dh.add_array("src", values_per_cell=9)
+    dh.fill("src", 0.0, ghost_layers=True)
+    dh.add_array("dst", values_per_cell=9)
+    dh.fill("dst", 0.0, ghost_layers=True)
+    dh.add_array("dst2", values_per_cell=9)
+    dh.fill("dst2", 0.0, ghost_layers=True)
+
+    dh.load_all('test_data/datahandling_load_test')
+    assert np.all(dh.cpu_arrays['src']) == 0
+    assert np.all(dh.cpu_arrays['dst']) == 0
+    assert np.all(dh.cpu_arrays['dst2']) == 0
-- 
GitLab