Skip to content
Snippets Groups Projects
Commit d2a2f47e authored by Rahil Doshi's avatar Rahil Doshi
Browse files

Update test_interpolation.cpp

parent f023410d
Branches
No related merge requests found
#include "interpolate_binary_search_cpp.h" #include "pymatlib_interpolators/interpolate_binary_search_cpp.h"
#include "interpolate_double_lookup_cpp.h" #include "pymatlib_interpolators/interpolate_double_lookup_cpp.h"
#include <array> #include <array>
#include <random> #include <random>
#include <chrono> #include <chrono>
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "TestArrayContainer.hpp" #include "TestArrayContainer.hpp"
// Helper function to compare floating point numbers // Helper function to compare floating point numbers
bool is_equal(double a, double b, double tolerance = 1e-10) { bool is_equal(const double a, const double b, double tolerance = 1e-10) {
return std::abs(a - b) < tolerance; return std::abs(a - b) < tolerance;
} }
...@@ -18,24 +18,24 @@ void test_basic_functionality() { ...@@ -18,24 +18,24 @@ void test_basic_functionality() {
std::cout << "\nTesting basic functionality..." << std::endl; std::cout << "\nTesting basic functionality..." << std::endl;
// Test middle point interpolation // Test middle point interpolation
const double test_E = 16950000000.0; constexpr double test_E = 16950000000.0;
// Binary Search Tests // Binary Search Tests
{ {
BinarySearchTests tests; constexpr BinarySearchTests tests;
double result = tests.interpolateBS(test_E); const double result = tests.interpolateBS(test_E);
// Expected value based on T_bs array // Expected value based on T_bs array
double expected_T = 3253.15; constexpr double expected_T = 3253.15;
assert(is_equal(result, expected_T)); assert(is_equal(result, expected_T));
std::cout << "Basic binary search interpolation passed" << std::endl; std::cout << "Basic binary search interpolation passed" << std::endl;
} }
// Double Lookup Tests // Double Lookup Tests
{ {
DoubleLookupTests tests; constexpr DoubleLookupTests tests;
double result = tests.interpolateDL(test_E); const double result = tests.interpolateDL(test_E);
// Expected value based on T_eq array // Expected value based on T_eq array
double expected_T = 3258.15; constexpr double expected_T = 3258.15;
assert(is_equal(result, expected_T)); assert(is_equal(result, expected_T));
std::cout << "Basic double lookup interpolation passed" << std::endl; std::cout << "Basic double lookup interpolation passed" << std::endl;
} }
...@@ -49,11 +49,11 @@ void test_edge_cases_and_errors() { ...@@ -49,11 +49,11 @@ void test_edge_cases_and_errors() {
BinarySearchTests tests; BinarySearchTests tests;
// Test below minimum // Test below minimum
double result_min = tests.interpolateBS(16750000000.0); const double result_min = tests.interpolateBS(16750000000.0);
assert(is_equal(result_min, 3243.15)); assert(is_equal(result_min, 3243.15));
// Test above maximum // Test above maximum
double result_max = tests.interpolateBS(17150000000.0); const double result_max = tests.interpolateBS(17150000000.0);
assert(is_equal(result_max, 3278.15)); assert(is_equal(result_max, 3278.15));
std::cout << "Edge cases for binary search passed" << std::endl; std::cout << "Edge cases for binary search passed" << std::endl;
...@@ -64,11 +64,11 @@ void test_edge_cases_and_errors() { ...@@ -64,11 +64,11 @@ void test_edge_cases_and_errors() {
DoubleLookupTests tests; DoubleLookupTests tests;
// Test below minimum // Test below minimum
double result_min = tests.interpolateDL(16750000000.0); const double result_min = tests.interpolateDL(16750000000.0);
assert(is_equal(result_min, 3243.15)); assert(is_equal(result_min, 3243.15));
// Test above maximum // Test above maximum
double result_max = tests.interpolateDL(17150000000.0); const double result_max = tests.interpolateDL(17150000000.0);
assert(is_equal(result_max, 3273.15)); assert(is_equal(result_max, 3273.15));
std::cout << "Edge cases for double lookup passed" << std::endl; std::cout << "Edge cases for double lookup passed" << std::endl;
...@@ -100,12 +100,12 @@ void test_interpolation_accuracy() { ...@@ -100,12 +100,12 @@ void test_interpolation_accuracy() {
for (const auto& test : test_cases) { for (const auto& test : test_cases) {
// Binary Search Tests // Binary Search Tests
BinarySearchTests bs_tests; BinarySearchTests bs_tests;
double result_bs = bs_tests.interpolateBS(test.input_E); const double result_bs = bs_tests.interpolateBS(test.input_E);
assert(is_equal(result_bs, test.expected_T_bs, test.tolerance)); assert(is_equal(result_bs, test.expected_T_bs, test.tolerance));
// Double Lookup Tests // Double Lookup Tests
DoubleLookupTests dl_tests; DoubleLookupTests dl_tests;
double result_dl = dl_tests.interpolateDL(test.input_E); const double result_dl = dl_tests.interpolateDL(test.input_E);
assert(is_equal(result_dl, test.expected_T_dl, test.tolerance)); assert(is_equal(result_dl, test.expected_T_dl, test.tolerance));
std::cout << "Accuracy test passed for " << test.description << std::endl; std::cout << "Accuracy test passed for " << test.description << std::endl;
...@@ -116,17 +116,17 @@ void test_consistency() { ...@@ -116,17 +116,17 @@ void test_consistency() {
std::cout << "\nTesting interpolation consistency..." << std::endl; std::cout << "\nTesting interpolation consistency..." << std::endl;
// Test that small changes in input produce correspondingly small changes in output // Test that small changes in input produce correspondingly small changes in output
BinarySearchTests bs_tests; constexpr BinarySearchTests bs_tests;
DoubleLookupTests dl_tests; constexpr DoubleLookupTests dl_tests;
const double base_E = 16900000000.0; constexpr double base_E = 16900000000.0;
const double delta_E = 1000000.0; // Small change in energy constexpr double delta_E = 1000000.0; // Small change in energy
double base_T_bs = bs_tests.interpolateBS(base_E); const double base_T_bs = bs_tests.interpolateBS(base_E);
double delta_T_bs = bs_tests.interpolateBS(base_E + delta_E) - base_T_bs; const double delta_T_bs = bs_tests.interpolateBS(base_E + delta_E) - base_T_bs;
double base_T_dl = dl_tests.interpolateDL(base_E); const double base_T_dl = dl_tests.interpolateDL(base_E);
double delta_T_dl = dl_tests.interpolateDL(base_E + delta_E) - base_T_dl; const double delta_T_dl = dl_tests.interpolateDL(base_E + delta_E) - base_T_dl;
// Check that changes are reasonable (not too large for small input change) // Check that changes are reasonable (not too large for small input change)
assert(std::abs(delta_T_bs) < 1.0); assert(std::abs(delta_T_bs) < 1.0);
...@@ -138,18 +138,18 @@ void test_consistency() { ...@@ -138,18 +138,18 @@ void test_consistency() {
void test_stress() { void test_stress() {
std::cout << "\nPerforming stress testing..." << std::endl; std::cout << "\nPerforming stress testing..." << std::endl;
BinarySearchTests bs_tests; constexpr BinarySearchTests bs_tests;
DoubleLookupTests dl_tests; constexpr DoubleLookupTests dl_tests;
// Test with extremely large values // Test with extremely large values
double large_E = 1.0e20; constexpr double large_E = 1.0e20;
double result_large_bs = bs_tests.interpolateBS(large_E); const double result_large_bs = bs_tests.interpolateBS(large_E);
double result_large_dl = dl_tests.interpolateDL(large_E); const double result_large_dl = dl_tests.interpolateDL(large_E);
// Test with extremely small values // Test with extremely small values
double small_E = 1.0e-20; constexpr double small_E = 1.0e-20;
double result_small_bs = bs_tests.interpolateBS(small_E); const double result_small_bs = bs_tests.interpolateBS(small_E);
double result_small_dl = dl_tests.interpolateDL(small_E); const double result_small_dl = dl_tests.interpolateDL(small_E);
// For extreme values, we should get boundary values // For extreme values, we should get boundary values
assert(is_equal(result_large_bs, 3278.15)); assert(is_equal(result_large_bs, 3278.15));
...@@ -163,19 +163,19 @@ void test_stress() { ...@@ -163,19 +163,19 @@ void test_stress() {
void test_random_validation() { void test_random_validation() {
std::cout << "\nTesting with random inputs..." << std::endl; std::cout << "\nTesting with random inputs..." << std::endl;
BinarySearchTests bs_tests; constexpr BinarySearchTests bs_tests;
DoubleLookupTests dl_tests; constexpr DoubleLookupTests dl_tests;
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::uniform_real_distribution<double> dist(16800000000.0, 17100000000.0); std::uniform_real_distribution<double> dist(16800000000.0, 17100000000.0);
const int num_tests = 1000; constexpr int num_tests = 1000;
for (int i = 0; i < num_tests; ++i) { for (int i = 0; i < num_tests; ++i) {
double random_E = dist(gen); const double random_E = dist(gen);
double result_bs = bs_tests.interpolateBS(random_E); const double result_bs = bs_tests.interpolateBS(random_E);
double result_dl = dl_tests.interpolateDL(random_E); const double result_dl = dl_tests.interpolateDL(random_E);
// Results should be within the temperature range // Results should be within the temperature range
assert(result_bs >= 3243.15 && result_bs <= 3278.15); assert(result_bs >= 3243.15 && result_bs <= 3278.15);
...@@ -195,14 +195,14 @@ void test_performance() { ...@@ -195,14 +195,14 @@ void test_performance() {
constexpr int numCells = 64*64*64; constexpr int numCells = 64*64*64;
// Setup test data // Setup test data
SS316L test; constexpr SS304L test;
std::vector<double> random_energies(numCells); std::vector<double> random_energies(numCells);
// Generate random values // Generate random values
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
const double E_min = SS316L::E_neq.front() * 0.8; constexpr double E_min = SS304L::E_neq.front() * 0.8;
const double E_max = SS316L::E_neq.back() * 1.2; constexpr double E_max = SS304L::E_neq.back() * 1.2;
std::uniform_real_distribution<double> dist(E_min, E_max); std::uniform_real_distribution<double> dist(E_min, E_max);
// Fill random energies // Fill random energies
...@@ -263,10 +263,10 @@ void test_performance() { ...@@ -263,10 +263,10 @@ void test_performance() {
// Calculate and print statistics // Calculate and print statistics
auto calc_stats = [](const std::vector<double>& timings) { auto calc_stats = [](const std::vector<double>& timings) {
double sum = std::accumulate(timings.begin(), timings.end(), 0.0); const double sum = std::accumulate(timings.begin(), timings.end(), 0.0);
double mean = sum / static_cast<double>(timings.size()); double mean = sum / static_cast<double>(timings.size());
double sq_sum = std::inner_product(timings.begin(), timings.end(), const double sq_sum = std::inner_product(timings.begin(), timings.end(),
timings.begin(), 0.0); timings.begin(), 0.0);
double stdev = std::sqrt(sq_sum / static_cast<double>(timings.size()) - mean * mean); double stdev = std::sqrt(sq_sum / static_cast<double>(timings.size()) - mean * mean);
return std::make_pair(mean, stdev); return std::make_pair(mean, stdev);
}; };
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment