From 1f716b641f4cf546f469d38943b4d2fdc5a9eb0a Mon Sep 17 00:00:00 2001
From: Rahil Doshi <rahil.doshi@fau.de>
Date: Wed, 19 Mar 2025 13:13:15 +0100
Subject: [PATCH] Improve interpolation code

---
 .../codegen/interpolation_array_container.py  | 50 ++++++++-----
 src/pymatlib/core/interpolators.py            | 71 +++++++++++++------
 2 files changed, 83 insertions(+), 38 deletions(-)

diff --git a/src/pymatlib/core/codegen/interpolation_array_container.py b/src/pymatlib/core/codegen/interpolation_array_container.py
index 1905294..240ea1c 100644
--- a/src/pymatlib/core/codegen/interpolation_array_container.py
+++ b/src/pymatlib/core/codegen/interpolation_array_container.py
@@ -1,3 +1,4 @@
+import re
 import numpy as np
 from pystencils.types import PsCustomType
 from pystencilssfg import SfgComposer
@@ -53,30 +54,45 @@ class InterpolationArrayContainer(CustomGenerator):
                 corresponding to temperature_array.
         Raises:
             ValueError: If arrays are empty, have different lengths, or are not monotonic.
+            TypeError: If name is not a string or arrays are not numpy arrays.
         """
         super().__init__()
+
+        # Validate inputs
+        if not isinstance(name, str) or not name:
+            raise TypeError("Name must be a non-empty string")
+
+        if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
+            raise ValueError(f"'{name}' is not a valid C++ class name")
+
+        if not isinstance(temperature_array, np.ndarray) or not isinstance(energy_density_array, np.ndarray):
+            raise TypeError("Temperature and energy arrays must be numpy arrays")
+
         self.name = name
         self.T_array = temperature_array
         self.E_array = energy_density_array
 
         # Prepare arrays and determine best method
-        self.data = prepare_interpolation_arrays(self.T_array, self.E_array)
-        self.method = self.data["method"]
-
-        # Store arrays for binary search (always available)
-        self.T_bs = self.data["T_bs"]
-        self.E_bs = self.data["E_bs"]
-
-        # Store arrays for double lookup if available
-        if self.method == "double_lookup":
-            self.T_eq = self.data["T_eq"]
-            self.E_neq = self.data["E_neq"]
-            self.E_eq = self.data["E_eq"]
-            self.inv_delta_E_eq = self.data["inv_delta_E_eq"]
-            self.idx_map = self.data["idx_map"]
-            self.has_double_lookup = True
-        else:
-            self.has_double_lookup = False
+        try:
+            self.data = prepare_interpolation_arrays(T_array=self.T_array, E_array=self.E_array, verbose=False)
+            self.method = self.data["method"]
+
+            # Store arrays for binary search (always available)
+            self.T_bs = self.data["T_bs"]
+            self.E_bs = self.data["E_bs"]
+
+            # Store arrays for double lookup if available
+            if self.method == "double_lookup":
+                self.T_eq = self.data["T_eq"]
+                self.E_neq = self.data["E_neq"]
+                self.E_eq = self.data["E_eq"]
+                self.inv_delta_E_eq = self.data["inv_delta_E_eq"]
+                self.idx_map = self.data["idx_map"]
+                self.has_double_lookup = True
+            else:
+                self.has_double_lookup = False
+        except Exception as e:
+            raise ValueError(f"Failed to prepare interpolation arrays: {e}") from e
 
     @classmethod
     def from_material(cls, name: str, material):
diff --git a/src/pymatlib/core/interpolators.py b/src/pymatlib/core/interpolators.py
index ad3c35f..64b6c98 100644
--- a/src/pymatlib/core/interpolators.py
+++ b/src/pymatlib/core/interpolators.py
@@ -323,10 +323,23 @@ def create_idx_mapping(E_neq: np.ndarray, E_eq: np.ndarray) -> np.ndarray:
     return idx_map.astype(np.int32)
 
 
-def prepare_interpolation_arrays(T_array: np.ndarray, E_array: np.ndarray) -> dict:
+def prepare_interpolation_arrays(T_array: np.ndarray, E_array: np.ndarray, verbose=False) -> dict:
     """
     Validates input arrays and prepares data for interpolation.
-    Returns a dictionary with arrays and metadata for the appropriate method.
+
+    Args:
+        T_array: Array of temperature values
+        E_array: Array of energy density values corresponding to temperatures
+        verbose: If True, prints diagnostic information during processing
+
+    Returns:
+        A dictionary with arrays and metadata for the appropriate interpolation method:
+        - Common keys: 'T_bs', 'E_bs', 'method', 'is_equidistant', 'increment'
+        - Additional keys for double_lookup method: 'T_eq', 'E_neq', 'E_eq',
+          'inv_delta_E_eq', 'idx_map'
+
+    Raises:
+        ValueError: If arrays don't meet requirements for interpolation
     """
     # Convert to numpy arrays if not already
     T_array = np.asarray(T_array)
@@ -364,43 +377,59 @@ def prepare_interpolation_arrays(T_array: np.ndarray, E_array: np.ndarray) -> di
 
     # Flip arrays if temperature is in descending order
     if T_decreasing:
-        print("Temperature array is descending, flipping arrays for processing")
+        if verbose:
+            print("Temperature array is descending, flipping arrays for processing")
         T_bs = np.flip(T_bs)
         E_bs = np.flip(E_bs)
 
+    # Check for strictly increasing values
+    has_warnings = False
     try:
         check_strictly_increasing(T_bs, "Temperature array")
         check_strictly_increasing(E_bs, "Energy density array")
     except ValueError as e:
-        print(f"Warning: {e}")
-        print("Continuing with interpolation, but results may be less accurate")
+        has_warnings = True
+        if verbose:
+            print(f"Warning: {e}")
+            print("Continuing with interpolation, but results may be less accurate")
 
-    # Use your existing check_equidistant function to determine if suitable for double lookup
+    # Use the existing check_equidistant function to determine if suitable for double lookup
     T_incr = check_equidistant(T_bs)
     is_equidistant = T_incr != 0.0
 
+    # Initialize result with common fields
     result = {
         "T_bs": T_bs,
         "E_bs": E_bs,
-        "method": "binary_search"
+        "method": "binary_search",
+        "is_equidistant": is_equidistant,
+        "increment": T_incr if is_equidistant else 0.0,
+        "has_warnings": has_warnings
     }
 
     # If temperature is equidistant, prepare for double lookup
     if is_equidistant:
-        print(f"Temperature array is equidistant with increment {T_incr}, using double lookup")
-        # Create equidistant energy array and mapping
-        E_eq, inv_delta_E_eq = E_eq_from_E_neq(E_bs)
-        idx_mapping = create_idx_mapping(E_bs, E_eq)
-
-        result.update({
-            "T_eq": T_bs,
-            "E_neq": E_bs,
-            "E_eq": E_eq,
-            "inv_delta_E_eq": inv_delta_E_eq,
-            "idx_map": idx_mapping,
-            "method": "double_lookup"
-        })
-    else:
+        if verbose:
+            print(f"Temperature array is equidistant with increment {T_incr}, using double lookup")
+
+        try:
+            # Create equidistant energy array and mapping
+            E_eq, inv_delta_E_eq = E_eq_from_E_neq(E_bs)
+            idx_mapping = create_idx_mapping(E_bs, E_eq)
+
+            result.update({
+                "T_eq": T_bs,
+                "E_neq": E_bs,
+                "E_eq": E_eq,
+                "inv_delta_E_eq": inv_delta_E_eq,
+                "idx_map": idx_mapping,
+                "method": "double_lookup"
+            })
+        except Exception as e:
+            if verbose:
+                print(f"Warning: Failed to create double lookup tables: {e}")
+                print("Falling back to binary search method")
+    elif verbose:
         print("Temperature array is not equidistant, using binary search")
 
     return result
-- 
GitLab