From 155510e7c63473092ca0d3692bd0946bdf8f9ec5 Mon Sep 17 00:00:00 2001
From: Rahil Doshi <rahil.doshi@fau.de>
Date: Mon, 10 Mar 2025 17:20:23 +0100
Subject: [PATCH] Refactor InterpolationArrayContainer

---
 src/pymatlib/core/interpolators.py | 133 +++++++++++++++++++++++------
 1 file changed, 108 insertions(+), 25 deletions(-)

diff --git a/src/pymatlib/core/interpolators.py b/src/pymatlib/core/interpolators.py
index e83415d..f41df0e 100644
--- a/src/pymatlib/core/interpolators.py
+++ b/src/pymatlib/core/interpolators.py
@@ -13,7 +13,54 @@ COUNT = 0
 
 
 class InterpolationArrayContainer(CustomGenerator):
+    """Container for energy-temperature interpolation arrays and methods.
+
+    This class stores temperature and energy density arrays and generates C++ code
+    for efficient bilateral conversion between these properties. It supports both
+    binary search interpolation (O(log n)) and double lookup interpolation (O(1))
+    with automatic method selection based on data characteristics.
+
+    Attributes:
+        name (str): Name for the generated C++ class.
+        T_array (np.ndarray): Array of temperature values (must be monotonically increasing).
+        E_array (np.ndarray): Array of energy density values corresponding to T_array.
+        method (str): Interpolation method selected ("binary_search" or "double_lookup").
+        T_bs (np.ndarray): Temperature array prepared for binary search.
+        E_bs (np.ndarray): Energy array prepared for binary search.
+        has_double_lookup (bool): Whether double lookup interpolation is available.
+
+    If has_double_lookup is True, the following attributes are also available:
+        T_eq (np.ndarray): Equidistant temperature array for double lookup.
+        E_neq (np.ndarray): Non-equidistant energy array for double lookup.
+        E_eq (np.ndarray): Equidistant energy array for double lookup.
+        inv_delta_E_eq (float): Inverse of the energy step size for double lookup.
+        idx_map (np.ndarray): Index mapping array for double lookup.
+
+    Examples:
+        >>> import numpy as np
+        >>> from pystencils_sfg import SfgComposer
+        >>> from pymatlib.core.interpolators import InterpolationArrayContainer
+        >>>
+        >>> # Create temperature and energy arrays
+        >>> T = np.array([300, 600, 900, 1200], dtype=np.float64)
+        >>> E = np.array([1e9, 2e9, 3e9, 4e9], dtype=np.float64)
+        >>>
+        >>> # Create and generate the container
+        >>> with SfgComposer() as sfg:
+        >>>     container = InterpolationArrayContainer("MyMaterial", T, E)
+        >>>     sfg.generate(container)
+    """
     def __init__(self, name: str, temperature_array: np.ndarray, energy_density_array: np.ndarray):
+        """Initialize the interpolation container.
+        Args:
+            name (str): Name for the generated C++ class.
+            temperature_array (np.ndarray): Array of temperature values (K).
+                Must be monotonically increasing.
+            energy_density_array (np.ndarray): Array of energy density values (J/m³)
+                corresponding to temperature_array.
+        Raises:
+            ValueError: If arrays are empty, have different lengths, or are not monotonic.
+        """
         super().__init__()
         self.name = name
         self.T_array = temperature_array
@@ -40,19 +87,29 @@ class InterpolationArrayContainer(CustomGenerator):
 
     @classmethod
     def from_material(cls, name: str, material):
+        """Create an interpolation container from a material object.
+        Args:
+            name (str): Name for the generated C++ class.
+            material: Material object with temperature and energy properties.
+                Must have energy_density_temperature_array and energy_density_array attributes.
+        Returns:
+            InterpolationArrayContainer: Container with arrays for interpolation.
+        """
         return cls(name, material.energy_density_temperature_array, material.energy_density_array)
 
-    def generate(self, sfg: SfgComposer):
-        sfg.include("<array>")
-        sfg.include("pymatlib_interpolators/interpolate_binary_search_cpp.h")
-
-        # Binary search arrays (always included)
+    def _generate_binary_search(self, sfg: SfgComposer):
+        """Generate code for binary search interpolation.
+        Args:
+            sfg (SfgComposer): Source file generator composer.
+        Returns:
+            list: List of public members for the C++ class.
+        """
         T_bs_arr_values = ", ".join(str(v) for v in self.T_bs)
         E_bs_arr_values = ", ".join(str(v) for v in self.E_bs)
 
         E_target = sfg.var("E_target", "double")
 
-        public_members = [
+        return [
             # Binary search arrays
             f"static constexpr std::array< double, {self.T_bs.shape[0]} > T_bs {{ {T_bs_arr_values} }}; \n"
             f"static constexpr std::array< double, {self.E_bs.shape[0]} > E_bs {{ {E_bs_arr_values} }}; \n",
@@ -63,30 +120,56 @@ class InterpolationArrayContainer(CustomGenerator):
             )
         ]
 
+    def _generate_double_lookup(self, sfg: SfgComposer):
+        """Generate code for double lookup interpolation.
+        Args:
+            sfg (SfgComposer): Source file generator composer.
+        Returns:
+            list: List of public members for the C++ class.
+        """
+        if not self.has_double_lookup:
+            return []
+
+        T_eq_arr_values = ", ".join(str(v) for v in self.T_eq)
+        E_neq_arr_values = ", ".join(str(v) for v in self.E_neq)
+        E_eq_arr_values = ", ".join(str(v) for v in self.E_eq)
+        idx_mapping_arr_values = ", ".join(str(v) for v in self.idx_map)
+
+        E_target = sfg.var("E_target", "double")
+
+        return [
+            # Double lookup arrays
+            f"static constexpr std::array< double, {self.T_eq.shape[0]} > T_eq {{ {T_eq_arr_values} }}; \n"
+            f"static constexpr std::array< double, {self.E_neq.shape[0]} > E_neq {{ {E_neq_arr_values} }}; \n"
+            f"static constexpr std::array< double, {self.E_eq.shape[0]} > E_eq {{ {E_eq_arr_values} }}; \n"
+            f"static constexpr double inv_delta_E_eq = {self.inv_delta_E_eq}; \n"
+            f"static constexpr std::array< int, {self.idx_map.shape[0]} > idx_map {{ {idx_mapping_arr_values} }}; \n",
+
+            # Double lookup method
+            sfg.method("interpolateDL", returns=PsCustomType("[[nodiscard]] double"), inline=True, const=True)(
+                sfg.expr("return interpolate_double_lookup_cpp({}, *this);", E_target)
+            )
+        ]
+
+    def generate(self, sfg: SfgComposer):
+        """Generate C++ code for the interpolation container.
+        This method generates a C++ class with the necessary arrays and methods
+        for temperature-energy interpolation.
+        Args:
+            sfg (SfgComposer): Source file generator composer.
+        """
+        sfg.include("<array>")
+        sfg.include("pymatlib_interpolators/interpolate_binary_search_cpp.h")
+
+        public_members = self._generate_binary_search(sfg)
+
         # Add double lookup if available
         if self.has_double_lookup:
             sfg.include("pymatlib_interpolators/interpolate_double_lookup_cpp.h")
-
-            T_eq_arr_values = ", ".join(str(v) for v in self.T_eq)
-            E_neq_arr_values = ", ".join(str(v) for v in self.E_neq)
-            E_eq_arr_values = ", ".join(str(v) for v in self.E_eq)
-            idx_mapping_arr_values = ", ".join(str(v) for v in self.idx_map)
-
-            public_members.extend([
-                # Double lookup arrays
-                f"static constexpr std::array< double, {self.T_eq.shape[0]} > T_eq {{ {T_eq_arr_values} }}; \n"
-                f"static constexpr std::array< double, {self.E_neq.shape[0]} > E_neq {{ {E_neq_arr_values} }}; \n"
-                f"static constexpr std::array< double, {self.E_eq.shape[0]} > E_eq {{ {E_eq_arr_values} }}; \n"
-                f"static constexpr double inv_delta_E_eq = {self.inv_delta_E_eq}; \n"
-                f"static constexpr std::array< int, {self.idx_map.shape[0]} > idx_map {{ {idx_mapping_arr_values} }}; \n",
-
-                # Double lookup method
-                sfg.method("interpolateDL", returns=PsCustomType("[[nodiscard]] double"), inline=True, const=True)(
-                    sfg.expr("return interpolate_double_lookup_cpp({}, *this);", E_target)
-                )
-            ])
+            public_members.extend(self._generate_double_lookup(sfg))
 
         # Add interpolate method that uses recommended approach
+        E_target = sfg.var("E_target", "double")
         if self.has_double_lookup:
             public_members.append(
                 sfg.method("interpolate", returns=PsCustomType("[[nodiscard]] double"), inline=True, const=True)(
-- 
GitLab