Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
pymatlib
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Rahil Doshi
pymatlib
Commits
d2a2f47e
Commit
d2a2f47e
authored
4 months ago
by
Rahil Doshi
Browse files
Options
Downloads
Patches
Plain Diff
Update test_interpolation.cpp
parent
f023410d
Branches
Branches containing commit
Tags
Tags containing commit
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
tests/cpp/test_interpolation.cpp
+44
-44
44 additions, 44 deletions
tests/cpp/test_interpolation.cpp
with
44 additions
and
44 deletions
tests/cpp/test_interpolation.cpp
+
44
−
44
View file @
d2a2f47e
#include
"interpolate_binary_search_cpp.h"
#include
"
pymatlib_interpolators/
interpolate_binary_search_cpp.h"
#include
"interpolate_double_lookup_cpp.h"
#include
"
pymatlib_interpolators/
interpolate_double_lookup_cpp.h"
#include
<array>
#include
<array>
#include
<random>
#include
<random>
#include
<chrono>
#include
<chrono>
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
#include
"TestArrayContainer.hpp"
#include
"TestArrayContainer.hpp"
// Helper function to compare floating point numbers
// Helper function to compare floating point numbers
bool
is_equal
(
double
a
,
double
b
,
double
tolerance
=
1e-10
)
{
bool
is_equal
(
const
double
a
,
const
double
b
,
double
tolerance
=
1e-10
)
{
return
std
::
abs
(
a
-
b
)
<
tolerance
;
return
std
::
abs
(
a
-
b
)
<
tolerance
;
}
}
...
@@ -18,24 +18,24 @@ void test_basic_functionality() {
...
@@ -18,24 +18,24 @@ void test_basic_functionality() {
std
::
cout
<<
"
\n
Testing basic functionality..."
<<
std
::
endl
;
std
::
cout
<<
"
\n
Testing basic functionality..."
<<
std
::
endl
;
// Test middle point interpolation
// Test middle point interpolation
const
double
test_E
=
16950000000.0
;
const
expr
double
test_E
=
16950000000.0
;
// Binary Search Tests
// Binary Search Tests
{
{
BinarySearchTests
tests
;
constexpr
BinarySearchTests
tests
;
double
result
=
tests
.
interpolateBS
(
test_E
);
const
double
result
=
tests
.
interpolateBS
(
test_E
);
// Expected value based on T_bs array
// Expected value based on T_bs array
double
expected_T
=
3253.15
;
constexpr
double
expected_T
=
3253.15
;
assert
(
is_equal
(
result
,
expected_T
));
assert
(
is_equal
(
result
,
expected_T
));
std
::
cout
<<
"Basic binary search interpolation passed"
<<
std
::
endl
;
std
::
cout
<<
"Basic binary search interpolation passed"
<<
std
::
endl
;
}
}
// Double Lookup Tests
// Double Lookup Tests
{
{
DoubleLookupTests
tests
;
constexpr
DoubleLookupTests
tests
;
double
result
=
tests
.
interpolateDL
(
test_E
);
const
double
result
=
tests
.
interpolateDL
(
test_E
);
// Expected value based on T_eq array
// Expected value based on T_eq array
double
expected_T
=
3258.15
;
constexpr
double
expected_T
=
3258.15
;
assert
(
is_equal
(
result
,
expected_T
));
assert
(
is_equal
(
result
,
expected_T
));
std
::
cout
<<
"Basic double lookup interpolation passed"
<<
std
::
endl
;
std
::
cout
<<
"Basic double lookup interpolation passed"
<<
std
::
endl
;
}
}
...
@@ -49,11 +49,11 @@ void test_edge_cases_and_errors() {
...
@@ -49,11 +49,11 @@ void test_edge_cases_and_errors() {
BinarySearchTests
tests
;
BinarySearchTests
tests
;
// Test below minimum
// Test below minimum
double
result_min
=
tests
.
interpolateBS
(
16750000000.0
);
const
double
result_min
=
tests
.
interpolateBS
(
16750000000.0
);
assert
(
is_equal
(
result_min
,
3243.15
));
assert
(
is_equal
(
result_min
,
3243.15
));
// Test above maximum
// Test above maximum
double
result_max
=
tests
.
interpolateBS
(
17150000000.0
);
const
double
result_max
=
tests
.
interpolateBS
(
17150000000.0
);
assert
(
is_equal
(
result_max
,
3278.15
));
assert
(
is_equal
(
result_max
,
3278.15
));
std
::
cout
<<
"Edge cases for binary search passed"
<<
std
::
endl
;
std
::
cout
<<
"Edge cases for binary search passed"
<<
std
::
endl
;
...
@@ -64,11 +64,11 @@ void test_edge_cases_and_errors() {
...
@@ -64,11 +64,11 @@ void test_edge_cases_and_errors() {
DoubleLookupTests
tests
;
DoubleLookupTests
tests
;
// Test below minimum
// Test below minimum
double
result_min
=
tests
.
interpolateDL
(
16750000000.0
);
const
double
result_min
=
tests
.
interpolateDL
(
16750000000.0
);
assert
(
is_equal
(
result_min
,
3243.15
));
assert
(
is_equal
(
result_min
,
3243.15
));
// Test above maximum
// Test above maximum
double
result_max
=
tests
.
interpolateDL
(
17150000000.0
);
const
double
result_max
=
tests
.
interpolateDL
(
17150000000.0
);
assert
(
is_equal
(
result_max
,
3273.15
));
assert
(
is_equal
(
result_max
,
3273.15
));
std
::
cout
<<
"Edge cases for double lookup passed"
<<
std
::
endl
;
std
::
cout
<<
"Edge cases for double lookup passed"
<<
std
::
endl
;
...
@@ -100,12 +100,12 @@ void test_interpolation_accuracy() {
...
@@ -100,12 +100,12 @@ void test_interpolation_accuracy() {
for
(
const
auto
&
test
:
test_cases
)
{
for
(
const
auto
&
test
:
test_cases
)
{
// Binary Search Tests
// Binary Search Tests
BinarySearchTests
bs_tests
;
BinarySearchTests
bs_tests
;
double
result_bs
=
bs_tests
.
interpolateBS
(
test
.
input_E
);
const
double
result_bs
=
bs_tests
.
interpolateBS
(
test
.
input_E
);
assert
(
is_equal
(
result_bs
,
test
.
expected_T_bs
,
test
.
tolerance
));
assert
(
is_equal
(
result_bs
,
test
.
expected_T_bs
,
test
.
tolerance
));
// Double Lookup Tests
// Double Lookup Tests
DoubleLookupTests
dl_tests
;
DoubleLookupTests
dl_tests
;
double
result_dl
=
dl_tests
.
interpolateDL
(
test
.
input_E
);
const
double
result_dl
=
dl_tests
.
interpolateDL
(
test
.
input_E
);
assert
(
is_equal
(
result_dl
,
test
.
expected_T_dl
,
test
.
tolerance
));
assert
(
is_equal
(
result_dl
,
test
.
expected_T_dl
,
test
.
tolerance
));
std
::
cout
<<
"Accuracy test passed for "
<<
test
.
description
<<
std
::
endl
;
std
::
cout
<<
"Accuracy test passed for "
<<
test
.
description
<<
std
::
endl
;
...
@@ -116,17 +116,17 @@ void test_consistency() {
...
@@ -116,17 +116,17 @@ void test_consistency() {
std
::
cout
<<
"
\n
Testing interpolation consistency..."
<<
std
::
endl
;
std
::
cout
<<
"
\n
Testing interpolation consistency..."
<<
std
::
endl
;
// Test that small changes in input produce correspondingly small changes in output
// Test that small changes in input produce correspondingly small changes in output
BinarySearchTests
bs_tests
;
constexpr
BinarySearchTests
bs_tests
;
DoubleLookupTests
dl_tests
;
constexpr
DoubleLookupTests
dl_tests
;
const
double
base_E
=
16900000000.0
;
const
expr
double
base_E
=
16900000000.0
;
const
double
delta_E
=
1000000.0
;
// Small change in energy
const
expr
double
delta_E
=
1000000.0
;
// Small change in energy
double
base_T_bs
=
bs_tests
.
interpolateBS
(
base_E
);
const
double
base_T_bs
=
bs_tests
.
interpolateBS
(
base_E
);
double
delta_T_bs
=
bs_tests
.
interpolateBS
(
base_E
+
delta_E
)
-
base_T_bs
;
const
double
delta_T_bs
=
bs_tests
.
interpolateBS
(
base_E
+
delta_E
)
-
base_T_bs
;
double
base_T_dl
=
dl_tests
.
interpolateDL
(
base_E
);
const
double
base_T_dl
=
dl_tests
.
interpolateDL
(
base_E
);
double
delta_T_dl
=
dl_tests
.
interpolateDL
(
base_E
+
delta_E
)
-
base_T_dl
;
const
double
delta_T_dl
=
dl_tests
.
interpolateDL
(
base_E
+
delta_E
)
-
base_T_dl
;
// Check that changes are reasonable (not too large for small input change)
// Check that changes are reasonable (not too large for small input change)
assert
(
std
::
abs
(
delta_T_bs
)
<
1.0
);
assert
(
std
::
abs
(
delta_T_bs
)
<
1.0
);
...
@@ -138,18 +138,18 @@ void test_consistency() {
...
@@ -138,18 +138,18 @@ void test_consistency() {
void
test_stress
()
{
void
test_stress
()
{
std
::
cout
<<
"
\n
Performing stress testing..."
<<
std
::
endl
;
std
::
cout
<<
"
\n
Performing stress testing..."
<<
std
::
endl
;
BinarySearchTests
bs_tests
;
constexpr
BinarySearchTests
bs_tests
;
DoubleLookupTests
dl_tests
;
constexpr
DoubleLookupTests
dl_tests
;
// Test with extremely large values
// Test with extremely large values
double
large_E
=
1.0e20
;
constexpr
double
large_E
=
1.0e20
;
double
result_large_bs
=
bs_tests
.
interpolateBS
(
large_E
);
const
double
result_large_bs
=
bs_tests
.
interpolateBS
(
large_E
);
double
result_large_dl
=
dl_tests
.
interpolateDL
(
large_E
);
const
double
result_large_dl
=
dl_tests
.
interpolateDL
(
large_E
);
// Test with extremely small values
// Test with extremely small values
double
small_E
=
1.0e-20
;
constexpr
double
small_E
=
1.0e-20
;
double
result_small_bs
=
bs_tests
.
interpolateBS
(
small_E
);
const
double
result_small_bs
=
bs_tests
.
interpolateBS
(
small_E
);
double
result_small_dl
=
dl_tests
.
interpolateDL
(
small_E
);
const
double
result_small_dl
=
dl_tests
.
interpolateDL
(
small_E
);
// For extreme values, we should get boundary values
// For extreme values, we should get boundary values
assert
(
is_equal
(
result_large_bs
,
3278.15
));
assert
(
is_equal
(
result_large_bs
,
3278.15
));
...
@@ -163,19 +163,19 @@ void test_stress() {
...
@@ -163,19 +163,19 @@ void test_stress() {
void
test_random_validation
()
{
void
test_random_validation
()
{
std
::
cout
<<
"
\n
Testing with random inputs..."
<<
std
::
endl
;
std
::
cout
<<
"
\n
Testing with random inputs..."
<<
std
::
endl
;
BinarySearchTests
bs_tests
;
constexpr
BinarySearchTests
bs_tests
;
DoubleLookupTests
dl_tests
;
constexpr
DoubleLookupTests
dl_tests
;
std
::
random_device
rd
;
std
::
random_device
rd
;
std
::
mt19937
gen
(
rd
());
std
::
mt19937
gen
(
rd
());
std
::
uniform_real_distribution
<
double
>
dist
(
16800000000.0
,
17100000000.0
);
std
::
uniform_real_distribution
<
double
>
dist
(
16800000000.0
,
17100000000.0
);
const
int
num_tests
=
1000
;
const
expr
int
num_tests
=
1000
;
for
(
int
i
=
0
;
i
<
num_tests
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_tests
;
++
i
)
{
double
random_E
=
dist
(
gen
);
const
double
random_E
=
dist
(
gen
);
double
result_bs
=
bs_tests
.
interpolateBS
(
random_E
);
const
double
result_bs
=
bs_tests
.
interpolateBS
(
random_E
);
double
result_dl
=
dl_tests
.
interpolateDL
(
random_E
);
const
double
result_dl
=
dl_tests
.
interpolateDL
(
random_E
);
// Results should be within the temperature range
// Results should be within the temperature range
assert
(
result_bs
>=
3243.15
&&
result_bs
<=
3278.15
);
assert
(
result_bs
>=
3243.15
&&
result_bs
<=
3278.15
);
...
@@ -195,14 +195,14 @@ void test_performance() {
...
@@ -195,14 +195,14 @@ void test_performance() {
constexpr
int
numCells
=
64
*
64
*
64
;
constexpr
int
numCells
=
64
*
64
*
64
;
// Setup test data
// Setup test data
SS3
16
L
test
;
constexpr
SS3
04
L
test
;
std
::
vector
<
double
>
random_energies
(
numCells
);
std
::
vector
<
double
>
random_energies
(
numCells
);
// Generate random values
// Generate random values
std
::
random_device
rd
;
std
::
random_device
rd
;
std
::
mt19937
gen
(
rd
());
std
::
mt19937
gen
(
rd
());
const
double
E_min
=
SS3
16
L
::
E_neq
.
front
()
*
0.8
;
const
expr
double
E_min
=
SS3
04
L
::
E_neq
.
front
()
*
0.8
;
const
double
E_max
=
SS3
16
L
::
E_neq
.
back
()
*
1.2
;
const
expr
double
E_max
=
SS3
04
L
::
E_neq
.
back
()
*
1.2
;
std
::
uniform_real_distribution
<
double
>
dist
(
E_min
,
E_max
);
std
::
uniform_real_distribution
<
double
>
dist
(
E_min
,
E_max
);
// Fill random energies
// Fill random energies
...
@@ -263,10 +263,10 @@ void test_performance() {
...
@@ -263,10 +263,10 @@ void test_performance() {
// Calculate and print statistics
// Calculate and print statistics
auto
calc_stats
=
[](
const
std
::
vector
<
double
>&
timings
)
{
auto
calc_stats
=
[](
const
std
::
vector
<
double
>&
timings
)
{
double
sum
=
std
::
accumulate
(
timings
.
begin
(),
timings
.
end
(),
0.0
);
const
double
sum
=
std
::
accumulate
(
timings
.
begin
(),
timings
.
end
(),
0.0
);
double
mean
=
sum
/
static_cast
<
double
>
(
timings
.
size
());
double
mean
=
sum
/
static_cast
<
double
>
(
timings
.
size
());
double
sq_sum
=
std
::
inner_product
(
timings
.
begin
(),
timings
.
end
(),
const
double
sq_sum
=
std
::
inner_product
(
timings
.
begin
(),
timings
.
end
(),
timings
.
begin
(),
0.0
);
timings
.
begin
(),
0.0
);
double
stdev
=
std
::
sqrt
(
sq_sum
/
static_cast
<
double
>
(
timings
.
size
())
-
mean
*
mean
);
double
stdev
=
std
::
sqrt
(
sq_sum
/
static_cast
<
double
>
(
timings
.
size
())
-
mean
*
mean
);
return
std
::
make_pair
(
mean
,
stdev
);
return
std
::
make_pair
(
mean
,
stdev
);
};
};
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment