From ceccff04001904996837ee86754b8344c2a32dbf Mon Sep 17 00:00:00 2001
From: Rahil Doshi <rahil.doshi@fau.de>
Date: Wed, 2 Apr 2025 17:00:05 +0200
Subject: [PATCH] Refactor interpolation_array_container.py

---
 .../codegen/interpolation_array_container.py  | 87 ++++---------------
 1 file changed, 15 insertions(+), 72 deletions(-)

diff --git a/src/pymatlib/core/codegen/interpolation_array_container.py b/src/pymatlib/core/codegen/interpolation_array_container.py
index 3912ff0..ffc1171 100644
--- a/src/pymatlib/core/codegen/interpolation_array_container.py
+++ b/src/pymatlib/core/codegen/interpolation_array_container.py
@@ -1,5 +1,6 @@
 import re
 import numpy as np
+from typing import List, Any
 from pystencils.types import PsCustomType
 from pystencilssfg import SfgComposer
 from pystencilssfg.composer.custom import CustomGenerator
@@ -7,55 +8,9 @@ from pymatlib.core.interpolators import prepare_interpolation_arrays
 
 
 class InterpolationArrayContainer(CustomGenerator):
-    """Container for x-y interpolation arrays and methods.
-
-    This class stores x and y arrays and generates C++ code
-    for efficient conversion to compute y for a given x. It supports both
-    binary search interpolation (O(log n)) and double lookup interpolation (O(1))
-    with automatic method selection based on data characteristics.
-
-    Attributes:
-        name (str): Name for the generated C++ class.
-        x_array (np.ndarray): Array of x values (must be monotonically increasing).
-        y_array (np.ndarray): Array of y values corresponding to x_array.
-        method (str): Interpolation method selected ("binary_search" or "double_lookup").
-        x_bs (np.ndarray): x array prepared for binary search.
-        y_bs (np.ndarray): y array prepared for binary search.
-        has_double_lookup (bool): Whether double lookup interpolation is available.
-
-    If has_double_lookup is True, the following attributes are also available:
-        x_eq (np.ndarray): Equidistant x array for double lookup.
-        y_neq (np.ndarray): Non-equidistant y array for double lookup.
-        y_eq (np.ndarray): Equidistant y array for double lookup.
-        inv_delta_y_eq (float): Inverse of the y step size for double lookup.
-        idx_map (np.ndarray): Index mapping array for double lookup.
-
-    Examples:
-        >>> import numpy as np
-        >>> from pystencilssfg import SfgComposer
-        >>> from pymatlib.core.codegen.interpolation_array_container import InterpolationArrayContainer
-        >>>
-        >>> # Create temperature and energy arrays
-        >>> T = np.array([300, 600, 900, 1200], dtype=np.float64)
-        >>> E = np.array([1e9, 2e9, 3e9, 4e9], dtype=np.float64)
-        >>>
-        >>> # Create and generate the container
-        >>> with SfgComposer() as sfg:
-        >>>     container = InterpolationArrayContainer("MyMaterial", T, E)
-        >>>     sfg.generate(container)
-    """
+    """Container for x-y interpolation arrays and methods."""
     def __init__(self, name: str, x_array: np.ndarray, y_array: np.ndarray):
-        """Initialize the interpolation container.
-        Args:
-            name (str): Name for the generated C++ class.
-            x_array (np.ndarray): Array of x values.
-                Must be monotonically increasing.
-            y_array (np.ndarray): Array of y values
-                corresponding to x_array.
-        Raises:
-            ValueError: If arrays are empty, have different lengths, or are not monotonic.
-            TypeError: If name is not a string or arrays are not numpy arrays.
-        """
+        """Initialize the interpolation container."""
         super().__init__()
 
         # Validate inputs
@@ -66,7 +21,7 @@ class InterpolationArrayContainer(CustomGenerator):
             raise ValueError(f"'{name}' is not a valid C++ class name")
 
         if not isinstance(x_array, np.ndarray) or not isinstance(y_array, np.ndarray):
-            raise TypeError("Temperature and energy arrays must be numpy arrays")
+            raise TypeError("x_array and y_array must be numpy arrays")
 
         self.name = name
         self.x_array = x_array
@@ -107,13 +62,8 @@ class InterpolationArrayContainer(CustomGenerator):
         """
         return cls(name, material.energy_density_temperature_array, material.y_array)'''
 
-    def _generate_binary_search(self, sfg: SfgComposer):
-        """Generate code for binary search interpolation.
-        Args:
-            sfg (SfgComposer): Source file generator composer.
-        Returns:
-            list: List of public members for the C++ class.
-        """
+    def _generate_binary_search(self, sfg: SfgComposer) -> List[Any]:
+        """Generate code for binary search interpolation."""
         x_bs_arr_values = ", ".join(str(v) for v in self.x_bs)
         y_bs_arr_values = ", ".join(str(v) for v in self.y_bs)
 
@@ -130,13 +80,8 @@ class InterpolationArrayContainer(CustomGenerator):
             )
         ]
 
-    def _generate_double_lookup(self, sfg: SfgComposer):
-        """Generate code for double lookup interpolation.
-        Args:
-            sfg (SfgComposer): Source file generator composer.
-        Returns:
-            list: List of public members for the C++ class.
-        """
+    def _generate_double_lookup(self, sfg: SfgComposer) -> List[Any]:
+        """Generate code for double lookup interpolation."""
         if not self.has_double_lookup:
             return []
 
@@ -161,26 +106,24 @@ class InterpolationArrayContainer(CustomGenerator):
             )
         ]
 
-    def generate(self, sfg: SfgComposer):
-        """Generate C++ code for the interpolation container.
-        This method generates a C++ class with the necessary arrays and methods
-        for temperature-energy interpolation.
-        Args:
-            sfg (SfgComposer): Source file generator composer.
-        """
+    def generate(self, sfg: SfgComposer) -> None:
+        """Generate C++ code for the interpolation container."""
         sfg.include("<array>")
         sfg.include("pymatlib_interpolators/interpolate_binary_search_cpp.h")
 
         public_members = self._generate_binary_search(sfg)
 
         # Add double lookup if available
-        if self.has_double_lookup:
+        """if self.has_double_lookup:
             sfg.include("pymatlib_interpolators/interpolate_double_lookup_cpp.h")
-            public_members.extend(self._generate_double_lookup(sfg))
+            public_members.extend(self._generate_double_lookup(sfg))"""
 
         # Add interpolate method that uses recommended approach
         y_target = sfg.var("y_target", "double")
         if self.has_double_lookup:
+            sfg.include("pymatlib_interpolators/interpolate_double_lookup_cpp.h")
+            public_members.extend(self._generate_double_lookup(sfg))
+
             public_members.append(
                 sfg.method("interpolate", returns=PsCustomType("[[nodiscard]] double"), inline=True, const=True)(
                     sfg.expr("return interpolate_double_lookup_cpp({}, *this);", y_target)
-- 
GitLab