From ceccff04001904996837ee86754b8344c2a32dbf Mon Sep 17 00:00:00 2001 From: Rahil Doshi <rahil.doshi@fau.de> Date: Wed, 2 Apr 2025 17:00:05 +0200 Subject: [PATCH] Refactor interpolation_array_container.py --- .../codegen/interpolation_array_container.py | 87 ++++--------------- 1 file changed, 15 insertions(+), 72 deletions(-) diff --git a/src/pymatlib/core/codegen/interpolation_array_container.py b/src/pymatlib/core/codegen/interpolation_array_container.py index 3912ff0..ffc1171 100644 --- a/src/pymatlib/core/codegen/interpolation_array_container.py +++ b/src/pymatlib/core/codegen/interpolation_array_container.py @@ -1,5 +1,6 @@ import re import numpy as np +from typing import List, Any from pystencils.types import PsCustomType from pystencilssfg import SfgComposer from pystencilssfg.composer.custom import CustomGenerator @@ -7,55 +8,9 @@ from pymatlib.core.interpolators import prepare_interpolation_arrays class InterpolationArrayContainer(CustomGenerator): - """Container for x-y interpolation arrays and methods. - - This class stores x and y arrays and generates C++ code - for efficient conversion to compute y for a given x. It supports both - binary search interpolation (O(log n)) and double lookup interpolation (O(1)) - with automatic method selection based on data characteristics. - - Attributes: - name (str): Name for the generated C++ class. - x_array (np.ndarray): Array of x values (must be monotonically increasing). - y_array (np.ndarray): Array of y values corresponding to x_array. - method (str): Interpolation method selected ("binary_search" or "double_lookup"). - x_bs (np.ndarray): x array prepared for binary search. - y_bs (np.ndarray): y array prepared for binary search. - has_double_lookup (bool): Whether double lookup interpolation is available. - - If has_double_lookup is True, the following attributes are also available: - x_eq (np.ndarray): Equidistant x array for double lookup. - y_neq (np.ndarray): Non-equidistant y array for double lookup. - y_eq (np.ndarray): Equidistant y array for double lookup. - inv_delta_y_eq (float): Inverse of the y step size for double lookup. - idx_map (np.ndarray): Index mapping array for double lookup. - - Examples: - >>> import numpy as np - >>> from pystencilssfg import SfgComposer - >>> from pymatlib.core.codegen.interpolation_array_container import InterpolationArrayContainer - >>> - >>> # Create temperature and energy arrays - >>> T = np.array([300, 600, 900, 1200], dtype=np.float64) - >>> E = np.array([1e9, 2e9, 3e9, 4e9], dtype=np.float64) - >>> - >>> # Create and generate the container - >>> with SfgComposer() as sfg: - >>> container = InterpolationArrayContainer("MyMaterial", T, E) - >>> sfg.generate(container) - """ + """Container for x-y interpolation arrays and methods.""" def __init__(self, name: str, x_array: np.ndarray, y_array: np.ndarray): - """Initialize the interpolation container. - Args: - name (str): Name for the generated C++ class. - x_array (np.ndarray): Array of x values. - Must be monotonically increasing. - y_array (np.ndarray): Array of y values - corresponding to x_array. - Raises: - ValueError: If arrays are empty, have different lengths, or are not monotonic. - TypeError: If name is not a string or arrays are not numpy arrays. - """ + """Initialize the interpolation container.""" super().__init__() # Validate inputs @@ -66,7 +21,7 @@ class InterpolationArrayContainer(CustomGenerator): raise ValueError(f"'{name}' is not a valid C++ class name") if not isinstance(x_array, np.ndarray) or not isinstance(y_array, np.ndarray): - raise TypeError("Temperature and energy arrays must be numpy arrays") + raise TypeError("x_array and y_array must be numpy arrays") self.name = name self.x_array = x_array @@ -107,13 +62,8 @@ class InterpolationArrayContainer(CustomGenerator): """ return cls(name, material.energy_density_temperature_array, material.y_array)''' - def _generate_binary_search(self, sfg: SfgComposer): - """Generate code for binary search interpolation. - Args: - sfg (SfgComposer): Source file generator composer. - Returns: - list: List of public members for the C++ class. - """ + def _generate_binary_search(self, sfg: SfgComposer) -> List[Any]: + """Generate code for binary search interpolation.""" x_bs_arr_values = ", ".join(str(v) for v in self.x_bs) y_bs_arr_values = ", ".join(str(v) for v in self.y_bs) @@ -130,13 +80,8 @@ class InterpolationArrayContainer(CustomGenerator): ) ] - def _generate_double_lookup(self, sfg: SfgComposer): - """Generate code for double lookup interpolation. - Args: - sfg (SfgComposer): Source file generator composer. - Returns: - list: List of public members for the C++ class. - """ + def _generate_double_lookup(self, sfg: SfgComposer) -> List[Any]: + """Generate code for double lookup interpolation.""" if not self.has_double_lookup: return [] @@ -161,26 +106,24 @@ class InterpolationArrayContainer(CustomGenerator): ) ] - def generate(self, sfg: SfgComposer): - """Generate C++ code for the interpolation container. - This method generates a C++ class with the necessary arrays and methods - for temperature-energy interpolation. - Args: - sfg (SfgComposer): Source file generator composer. - """ + def generate(self, sfg: SfgComposer) -> None: + """Generate C++ code for the interpolation container.""" sfg.include("<array>") sfg.include("pymatlib_interpolators/interpolate_binary_search_cpp.h") public_members = self._generate_binary_search(sfg) # Add double lookup if available - if self.has_double_lookup: + """if self.has_double_lookup: sfg.include("pymatlib_interpolators/interpolate_double_lookup_cpp.h") - public_members.extend(self._generate_double_lookup(sfg)) + public_members.extend(self._generate_double_lookup(sfg))""" # Add interpolate method that uses recommended approach y_target = sfg.var("y_target", "double") if self.has_double_lookup: + sfg.include("pymatlib_interpolators/interpolate_double_lookup_cpp.h") + public_members.extend(self._generate_double_lookup(sfg)) + public_members.append( sfg.method("interpolate", returns=PsCustomType("[[nodiscard]] double"), inline=True, const=True)( sfg.expr("return interpolate_double_lookup_cpp({}, *this);", y_target) -- GitLab