From 5c8e0d9adf0d96a6b30e6798457a4ffae18c1e8d Mon Sep 17 00:00:00 2001
From: Marcus Mohr <marcus.mohr@lmu.de>
Date: Mon, 3 Mar 2025 15:27:48 +0100
Subject: [PATCH] Draft for resolving issue 48

The commit
- adds component_index as keyword to IntegrandSymbols
- makes gradient and divergence form pass component_index to process_integrand
- the corresponding integrands now check the shape of grad_u resp. grad_v
  and, if it is not square, revert to extracting the corresponding component
- moves the is_implemented check from the call of HyTeGIntegrator outwards,
  so that the call to form_func is protected against out-of-bounds indices
---
 generate_all_hyteg_forms.py                 | 38 ++++++++++-----------
 hog/forms.py                                |  2 ++
 hog/integrand.py                            |  9 ++++-
 hog/recipes/integrands/volume/divergence.py | 21 +++++++++---
 hog/recipes/integrands/volume/gradient.py   | 21 +++++++++---
 5 files changed, 60 insertions(+), 31 deletions(-)

diff --git a/generate_all_hyteg_forms.py b/generate_all_hyteg_forms.py
index b37446c..96a7b26 100644
--- a/generate_all_hyteg_forms.py
+++ b/generate_all_hyteg_forms.py
@@ -1135,30 +1135,28 @@ def main():
                                 inline_values=form_info.inline_quad,
                             )
 
-                            mat = form_func(
-                                form_info.form_name,
-                                row,
-                                col,
-                                trial,
-                                test,
-                                geometry,
-                                quad,
-                                symbolizer,
-                                blending=form_info.blending,
-                            )
-                            form_codes.append(
-                                HyTeGIntegrator(
-                                    form_info.class_name(row, col),
-                                    mat,
+                            if form_info.is_implemented( row, col, geometry.dimensions ):
+                                mat = form_func(
+                                    form_info.form_name,
+                                    row,
+                                    col,
+                                    trial,
+                                    test,
                                     geometry,
                                     quad,
                                     symbolizer,
-                                    not_implemented=not form_info.is_implemented(
-                                        row, col, geometry.dimensions
-                                    ),
-                                    integrate_rows=form_info.integrate_rows,
+                                    blending=form_info.blending,
+                                )
+                                form_codes.append(
+                                    HyTeGIntegrator(
+                                        form_info.class_name(row, col),
+                                        mat,
+                                        geometry,
+                                        quad,
+                                        symbolizer,
+                                        integrate_rows=form_info.integrate_rows,
+                                    )
                                 )
-                            )
                 form_classes.append(
                     HyTeGFormClass(
                         form_info.class_name(row, col),
diff --git a/hog/forms.py b/hog/forms.py
index ec630ce..5ed4b0e 100644
--- a/hog/forms.py
+++ b/hog/forms.py
@@ -579,6 +579,7 @@ Weak formulation
         geometry,
         symbolizer,
         blending=blending,
+        component_index=component_index,
         is_symmetric=False,
         docstring=docstring,
         rot_type=RotationType.POST_MULTIPLY
@@ -620,6 +621,7 @@ def gradient(
         geometry,
         symbolizer,
         blending=blending,
+        component_index=component_index,
         is_symmetric=False,
         docstring=docstring,
         rot_type=RotationType.PRE_MULTIPLY
diff --git a/hog/integrand.py b/hog/integrand.py
index 0731370..a95b193 100644
--- a/hog/integrand.py
+++ b/hog/integrand.py
@@ -220,6 +220,9 @@ class IntegrandSymbols:
     #
     tabulate: Callable[[Union[sp.Expr, sp.Matrix], str], sp.Matrix] | None = None
 
+    # For backward compatibility with (sub-)form generation this integer allows to select a component
+    component_index: int | None = None
+
 
 def process_integrand(
     integrand: Callable[..., Any],
@@ -230,6 +233,7 @@ def process_integrand(
     blending: GeometryMap = IdentityMap(),
     boundary_geometry: ElementGeometry | None = None,
     fe_coefficients: Dict[str, Union[FunctionSpace, None]] | None = None,
+    component_index: int | None = None,
     is_symmetric: bool = False,
     rot_type: RotationType = RotationType.NO_ROTATION,
     docstring: str = "",
@@ -293,7 +297,7 @@ def process_integrand(
                             finite-element function coefficients, they will be available to the callable as `k`
                             supply None as the FunctionSpace for a std::function-type coeff (only works for old forms)
     :param is_symmetric: whether the bilinear form is symmetric - this is exploited by the generator
-    :param rot_type: whether the  operator has to be wrapped with rotation matrix and the type of rotation that needs 
+    :param rot_type: whether the  operator has to be wrapped with rotation matrix and the type of rotation that needs
                      to be applied, only applicable for Vectorial spaces
     :param docstring: documentation of the integrand/bilinear form - will end up in the docstring of the generated code
     """
@@ -482,6 +486,9 @@ def process_integrand(
     mat = create_empty_element_matrix(trial, test, volume_geometry)
     it = element_matrix_iterator(trial, test, volume_geometry)
 
+    if component_index is not None:
+        s.component_index = component_index
+
     for data in it:
         s.u = data.trial_shape
         s.grad_u = data.trial_shape_grad
diff --git a/hog/recipes/integrands/volume/divergence.py b/hog/recipes/integrands/volume/divergence.py
index 75ae751..f2ea653 100644
--- a/hog/recipes/integrands/volume/divergence.py
+++ b/hog/recipes/integrands/volume/divergence.py
@@ -26,10 +26,21 @@ def integrand(
     grad_u,
     v,
     tabulate,
+    component_index,
     **_,
 ):
-    return (
-        -(jac_b_inv.T * tabulate(jac_a_inv.T * grad_u)).trace()
-        * tabulate(v * jac_a_abs_det)
-        * jac_b_abs_det
-    )
+    # working with vector-valued functions
+    if grad_u.is_square:
+        return (
+            -(jac_b_inv.T * tabulate(jac_a_inv.T * grad_u)).trace()
+            * tabulate(v * jac_a_abs_det)
+            * jac_b_abs_det
+        )
+
+    # working with scalar-valued functions (backward compatibility)
+    else:
+        return (
+            -(jac_b_inv.T * tabulate(jac_a_inv.T * grad_u))[component_index]
+            * tabulate(v * jac_a_abs_det)
+            * jac_b_abs_det
+        )
diff --git a/hog/recipes/integrands/volume/gradient.py b/hog/recipes/integrands/volume/gradient.py
index 982ba23..134271d 100644
--- a/hog/recipes/integrands/volume/gradient.py
+++ b/hog/recipes/integrands/volume/gradient.py
@@ -26,10 +26,21 @@ def integrand(
     u,
     grad_v,
     tabulate,
+    component_index,
     **_,
 ):
-    return (
-        -(jac_b_inv.T * tabulate(jac_a_inv.T * grad_v)).trace()
-        * tabulate(u * jac_a_abs_det)
-        * jac_b_abs_det
-    )
+    # working with vector-valued functions
+    if grad_v.is_square:
+        return (
+            -(jac_b_inv.T * tabulate(jac_a_inv.T * grad_v)).trace()
+            * tabulate(u * jac_a_abs_det)
+            * jac_b_abs_det
+        )
+
+    # working with scalar-valued functions (backward compatibility)
+    else:
+        return (
+            -(jac_b_inv.T * tabulate(jac_a_inv.T * grad_v))[component_index]
+            * tabulate(u * jac_a_abs_det)
+            * jac_b_abs_det
+        )
-- 
GitLab