diff --git a/src/pymatlib/core/interpolators.py b/src/pymatlib/core/interpolators.py index 78e89b95eaa7092d0c13042757cdcb6ea646aacb..e517442bf54d8c8f016f4c8056da8a0c7459cc4a 100644 --- a/src/pymatlib/core/interpolators.py +++ b/src/pymatlib/core/interpolators.py @@ -19,13 +19,24 @@ class InterpolationArrayContainer(CustomGenerator): self.T_array = temperature_array self.E_array = energy_density_array - # Prepare arrays for double lookup - self.T_eq, self.E_neq, self.E_eq, self.inv_delta_E_eq, self.idx_mapping = \ - prepare_interpolation_arrays(self.T_array, self.E_array) - - # Store original arrays for binary search - self.T_bs = self.T_array - self.E_bs = self.E_array + # Prepare arrays and determine best method + self.data = prepare_interpolation_arrays(self.T_array, self.E_array) + self.method = self.data["method"] + + # Store arrays for binary search (always available) + self.T_bs = self.data["T_bs"] + self.E_bs = self.data["E_bs"] + + # Store arrays for double lookup if available + if self.method == "double_lookup": + self.T_eq = self.data["T_eq"] + self.E_neq = self.data["E_neq"] + self.E_eq = self.data["E_eq"] + self.inv_delta_E_eq = self.data["inv_delta_E_eq"] + self.idx_map = self.data["idx_map"] + self.has_double_lookup = True + else: + self.has_double_lookup = False @classmethod def from_material(cls, name: str, material): @@ -33,49 +44,65 @@ class InterpolationArrayContainer(CustomGenerator): def generate(self, sfg: SfgComposer): sfg.include("<array>") - sfg.include("interpolate_double_lookup_cpp.h") sfg.include("interpolate_binary_search_cpp.h") - T_eq_arr_values = ", ".join(str(v) for v in self.T_eq) - E_neq_arr_values = ", ".join(str(v) for v in self.E_neq) - E_eq_arr_values = ", ".join(str(v) for v in self.E_eq) - idx_mapping_arr_values = ", ".join(str(v) for v in self.idx_mapping) - - # Binary search arrays + # Binary search arrays (always included) T_bs_arr_values = ", ".join(str(v) for v in self.T_bs) E_bs_arr_values = ", ".join(str(v) for v in self.E_bs) E_target = sfg.var("E_target", "double") - sfg.klass(self.name)( + public_members = [ + # Binary search arrays + f"static constexpr std::array< double, {self.T_bs.shape[0]} > T_bs {{ {T_bs_arr_values} }}; \n" + f"static constexpr std::array< double, {self.E_bs.shape[0]} > E_bs {{ {E_bs_arr_values} }}; \n", + + # Binary search method + sfg.method("interpolateBS", returns=PsCustomType("double"), inline=True, const=True)( + sfg.expr("return interpolate_binary_search_cpp({}, *this);", E_target) + ) + ] + + # Add double lookup if available + if self.has_double_lookup: + sfg.include("interpolate_double_lookup_cpp.h") + + T_eq_arr_values = ", ".join(str(v) for v in self.T_eq) + E_neq_arr_values = ", ".join(str(v) for v in self.E_neq) + E_eq_arr_values = ", ".join(str(v) for v in self.E_eq) + idx_mapping_arr_values = ", ".join(str(v) for v in self.idx_map) - sfg.public( + public_members.extend([ # Double lookup arrays f"static constexpr std::array< double, {self.T_eq.shape[0]} > T_eq = {{ {T_eq_arr_values} }}; \n" f"static constexpr std::array< double, {self.E_neq.shape[0]} > E_neq = {{ {E_neq_arr_values} }}; \n" f"static constexpr std::array< double, {self.E_eq.shape[0]} > E_eq = {{ {E_eq_arr_values} }}; \n" f"static constexpr double inv_delta_E_eq = {self.inv_delta_E_eq}; \n" - f"static constexpr std::array< int, {self.idx_mapping.shape[0]} > idx_map = {{ {idx_mapping_arr_values} }}; \n", - - # Binary search arrays - f"static constexpr std::array< double, {self.T_bs.shape[0]} > T_bs = {{ {T_bs_arr_values} }}; \n" - f"static constexpr std::array< double, {self.E_bs.shape[0]} > E_bs = {{ {E_bs_arr_values} }}; \n", - - #TODO! - # create constructor - # sfg.constructor(self.T_eq, self.E_neq, self.E_eq, self.inv_delta_E_eq, self.idx_mapping) - # sfg.constructor(sfg.var("T_eq", PsCustomType("std::array< double, N >"))), + f"static constexpr std::array< int, {self.idx_map.shape[0]} > idx_map = {{ {idx_mapping_arr_values} }}; \n", # Double lookup method sfg.method("interpolateDL", returns=PsCustomType("double"), inline=True, const=True)( sfg.expr("return interpolate_double_lookup_cpp({}, *this);", E_target) - ), + ) + ]) - # Binary search method - sfg.method("interpolateBS", returns=PsCustomType("double"), inline=True, const=True)( + # Add interpolate method that uses recommended approach + if self.has_double_lookup: + public_members.append( + sfg.method("interpolate", returns=PsCustomType("double"), inline=True, const=True)( + sfg.expr("return interpolate_double_lookup_cpp({}, *this);", E_target) + ) + ) + else: + public_members.append( + sfg.method("interpolate", returns=PsCustomType("double"), inline=True, const=True)( sfg.expr("return interpolate_binary_search_cpp({}, *this);", E_target) ) ) + + # Generate the class + sfg.klass(self.name)( + sfg.public(*public_members) )