diff --git a/cbutil/__init__.py b/cbutil/__init__.py index 91c8e0a270f9299a76db9dac1aa60e487277f1b3..310a18d7af56b801438885462a9ff0a2e66befe5 100644 --- a/cbutil/__init__.py +++ b/cbutil/__init__.py @@ -6,3 +6,4 @@ from .data_points import DataPoint, data_point_factory from .get_job_info import get_url_from_env, get_job_datapoints from .gitlab_api import get_git_infos_from_api from .update_data import get_updated_data +from .compare_vtu import compare_vtu_files diff --git a/cbutil/compare_vtu.py b/cbutil/compare_vtu.py new file mode 100644 index 0000000000000000000000000000000000000000..9d39abe42f17e9c655f358afa214168ebb565aad --- /dev/null +++ b/cbutil/compare_vtu.py @@ -0,0 +1,39 @@ +import xml.dom.minidom +import logging +import numpy as np +logger = logging.getLogger(__name__) + + +def read_vtu_file(file_name): + try: + return xml.dom.minidom.parse(file_name) + except Exception as e: + logger.error(f"Error parsing {file_name}: {str(e)}") + raise + + +def extract_data_array(output, name): + data_arrays = output.getElementsByTagName("DataArray") + data_array = next( + (da for da in data_arrays if da.attributes['Name'].value == name), None) + if data_array is None: + logging.error(f"Data array '{name}' not found in VTU file.") + return None + floats = [float(n) for n in data_array.firstChild.data.strip().split()] + return np.array(floats, dtype=data_array.attributes["type"].value.lower()) + + +def calculate_L2_norm(ref, data): + return np.linalg.norm(ref - data) + + +def compare_vtu_files(ref_file, result_file, arrays): + ref_output = read_vtu_file(ref_file) + result_output = read_vtu_file(result_file) + differences = {} + for array in arrays: + ref = extract_data_array(ref_output, array) + result = extract_data_array(result_output, array) + if ref is not None and result is not None: + differences[f'{array}_L2_diff'] = calculate_L2_norm(ref, result) + return differences diff --git a/pyproject.toml b/pyproject.toml index e114aaa9a6310a13a25e1ceff3ec14d887c80298..50b29e4cbc5457f2da0686c9d2998b3abc518809 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "gitpython", "requests", "kadi-apy", + "numpy", "importlib_resources ; python_version<'3.7'", ] diff --git a/tests/test_compare_vtu.py b/tests/test_compare_vtu.py new file mode 100644 index 0000000000000000000000000000000000000000..09f21e14a41c1e0db75b2907e38482a2da0ce7c4 --- /dev/null +++ b/tests/test_compare_vtu.py @@ -0,0 +1,56 @@ +import os +import pytest +import tempfile +from cbutil.compare_vtu import compare_vtu_files + + +@pytest.fixture +def setup_test_files(): + # Create a reference VTU file and a result VTU file for testing + + # Assuming the reference VTU file contains data arrays 'U1', 'U2', and 'U3' + # with values [1.0, 2.0, 3.0], [4.0, 5.0, 6.0], and [7.0, 8.0, 9.0], respectively. + # The result VTU file contains identical data arrays. + + # Create the reference VTU file + with tempfile.TemporaryDirectory() as tmp_dir: + ref_file = f"{tmp_dir}/reference_test_file.vtu" + result_file = f"{tmp_dir}/result_test_file.vtu" + with open(ref_file, "w") as f: + f.write('<?xml version="1.0"?>\n<VTKFile type="UnstructuredGrid" version="0.1" byte_order="LittleEndian">\n') + f.write('<UnstructuredGrid>\n<PointData>\n') + for array_name, values in zip(['U1', 'U2', 'U3'], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]): + f.write(f'<DataArray type="Float64" Name="{array_name}" format="ascii">\n') + f.write(" ".join(map(str, values)) + "\n") + f.write("</DataArray>\n") + f.write('</PointData>\n</UnstructuredGrid>\n</VTKFile>\n') + + # Create the result VTU file (identical to the reference) + with open(result_file, "w") as f: + f.write('<?xml version="1.0"?>\n<VTKFile type="UnstructuredGrid" version="0.1" byte_order="LittleEndian">\n') + f.write('<UnstructuredGrid>\n<PointData>\n') + for array_name, values in zip(['U1', 'U2', 'U3'], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]): + f.write(f'<DataArray type="Float64" Name="{array_name}" format="ascii">\n') + f.write(" ".join(map(str, values)) + "\n") + f.write("</DataArray>\n") + f.write('</PointData>\n</UnstructuredGrid>\n</VTKFile>\n') + yield ref_file, result_file + + +def test_valid_comparison(setup_test_files): + ref_file, result_file = setup_test_files + result = compare_vtu_files(ref_file, result_file, ['U1', 'U2', 'U3']) + assert result['U1_L2_diff'] == pytest.approx(0.0) + assert result['U2_L2_diff'] == pytest.approx(0.0) + assert result['U3_L2_diff'] == pytest.approx(0.0) + + +def test_missing_reference_file(): + with pytest.raises(FileNotFoundError): + compare_vtu_files("non_existent_file.vtu", "other_file.vtu", ["U1"]) + + +def test_missing_data_array(setup_test_files): + ref_file, result_file = setup_test_files + result = compare_vtu_files(ref_file, result_file, ['U1', 'U2', 'U3', 'U4']) + assert 'U4_L2_diff' not in result