diff --git a/hog/fem_helpers.py b/hog/fem_helpers.py index 56e0fae22664f0dc375c988dc26ce7305d6bd337..2a6ba498355dab13a72242cdb3853a14c8145e9a 100644 --- a/hog/fem_helpers.py +++ b/hog/fem_helpers.py @@ -388,7 +388,7 @@ def scalar_space_dependent_coefficient( elif isinstance(geometry, LineElement): coeff_class = ScalarVariableCoefficient3D - return coeff_class(sp.Symbol(name), 0, *t) + return coeff_class(sp.Symbol(f"{name}_std_function"), 0, *t) def vector_space_dependent_coefficient( diff --git a/hog/operator_generation/operators.py b/hog/operator_generation/operators.py index 18dbfede7ce05b0f9b57714e3c3203fd8f7a8918..77cdde1282af68cb81046aa57cd64a7e9c47cc24 100644 --- a/hog/operator_generation/operators.py +++ b/hog/operator_generation/operators.py @@ -24,6 +24,7 @@ from textwrap import indent import pystencils.astnodes import numpy as np import sympy as sp +from collections import defaultdict from hog.cpp_printing import ( CppClass, @@ -101,6 +102,7 @@ from hog.operator_generation.optimizer import Optimizer, Opts from hog.quadrature import QuadLoop, Quadrature from hog.symbolizer import Symbolizer from hog.operator_generation.types import HOGType +from hog.multi_assignment import MultiAssignment class MacroIntegrationDomain(Enum): @@ -219,7 +221,6 @@ def micro_vertex_permutation_for_facet( """ if volume_geometry == TriangleElement(): - if element_type == FaceType.BLUE: return [0, 1, 2] @@ -232,7 +233,6 @@ def micro_vertex_permutation_for_facet( return shuffle_order_gray[facet_id] elif volume_geometry == TetrahedronElement(): - if element_type == CellType.WHITE_DOWN: return [0, 1, 2, 3] @@ -300,6 +300,7 @@ class HyTeGElementwiseOperator: # coefficients self.coeffs: Dict[str, FunctionSpaceImpl] = {} + self.std_function_coeffs: Dict[str, MultiAssignment] = {} # implementations for each kernel, generated at a later stage self.operator_methods: List[OperatorMethod] = [] @@ -538,6 +539,9 @@ class HyTeGElementwiseOperator: """ return sorted(self.coeffs.values(), key=lambda c: c.name) + def std_function_coefficients(self) -> List[MultiAssignment]: + return self.std_function_coeffs + def generate_class_code( self, dir_path: str, @@ -556,7 +560,6 @@ class HyTeGElementwiseOperator: with TimedLogger( f"Generating kernels for operator {self.name}", level=logging.INFO ): - # Generate each kernel type (apply, gemv, ...). self.generate_kernels() @@ -601,7 +604,6 @@ class HyTeGElementwiseOperator: with TimedLogger("Generating C code from kernel AST(s)"): # Add all kernels to the class. for operator_method in self.operator_methods: - num_integrals = len(operator_method.integration_infos) if num_integrals != len( @@ -747,10 +749,23 @@ class HyTeGElementwiseOperator: ) for coeff in self.coefficients() ] + + [ + CppVariable( + name=f"_{std_coeff.variable_name()}", + type="std::function< real_t ( const Point3D & ) >", + is_const=True, + is_reference=False, + ) + for _, std_coeff in self.std_function_coeffs.items() + ] + free_symbol_vars + boundary_condition_vars, initializer_list=["Operator( storage, minLevel, maxLevel )"] + [f"{coeff.name}( _{coeff.name} )" for coeff in self.coefficients()] + + [ + f"{std_coeff.variable_name()}( _{std_coeff.variable_name()} )" + for _, std_coeff in self.std_function_coeffs.items() + ] + [ f"{fsv[0].name}( {fsv[1].name} )" for fsv in zip(free_symbol_vars_members, free_symbol_vars) @@ -776,6 +791,17 @@ class HyTeGElementwiseOperator: ) ) + for _, std_coeff in self.std_function_coeffs.items(): + operator_cpp_class.add( + CppMemberVariable( + CppVariable( + name=f"{std_coeff.variable_name()}_", + type="std::function< real_t ( const Point3D & ) >", + ), + visibility="private", + ) + ) + for fsv in free_symbol_vars_members: operator_cpp_class.add(CppMemberVariable(fsv, visibility="private")) @@ -797,7 +823,6 @@ class HyTeGElementwiseOperator: ) blending_includes = set() for dim, integration_infos in self.integration_infos.items(): - if not all( [ integration_infos[0].blending.coupling_includes() @@ -1193,7 +1218,6 @@ class HyTeGElementwiseOperator: element_types = list(integration_info.loop_strategy.element_loops.keys()) for element_type in element_types: - # Re-ordering micro-element vertices for the handling of domain boundary integrals. # # Boundary integrals are handled by looping over all (volume-)elements that have a facet at one of the @@ -1379,6 +1403,15 @@ class HyTeGElementwiseOperator: return_type=SympyAssignment, ) + for assignment in kernel_op_assignments: + if isinstance(assignment.rhs, MultiAssignment): + ma = assignment.rhs + self.std_function_coeffs[f"{ma.variable_name()}"] = ma + input_args = ma.args[2:] + assignment.rhs = sp.Symbol( + f"{ma.variable_name()}_(hyteg::Point3D({input_args[0]}, {input_args[1]}, {input_args[2] if len(input_args) == 3 else 0.0}))" + ) + body = ( loop_counter_custom_code_nodes + coords_assignments @@ -1529,7 +1562,6 @@ class HyTeGElementwiseOperator: for kernel_wrapper_type in self.kernel_wrapper_types: for dim, integration_infos in self.integration_infos.items(): - kernel_functions = [] kernel_op_counts = [] platform_dep_kernels = [] @@ -1545,13 +1577,11 @@ class HyTeGElementwiseOperator: ) for integration_info in integration_infos: - # generate AST of kernel loop with TimedLogger( f"Generating kernel {integration_info.name} ({kernel_wrapper_type.name}, {dim}D)", logging.INFO, ): - ( function_body, kernel_op_count, @@ -1653,7 +1683,6 @@ class HyTeGElementwiseOperator: for kernel_function, integration_info in zip( kernel_functions, integration_infos ): - pre_call_code = "" post_call_code = "" @@ -1661,7 +1690,6 @@ class HyTeGElementwiseOperator: integration_info.integration_domain == MacroIntegrationDomain.DOMAIN_BOUNDARY ): - if not isinstance(integration_info.loop_strategy, BOUNDARY): raise HOGException( "The loop strategy should be BOUNDARY for boundary integrals."