From b00da3b66eb5b2507604d4408e28c80e9d64ae2a Mon Sep 17 00:00:00 2001
From: Nils Kohl <nils.kohl@fau.de>
Date: Wed, 14 Aug 2024 10:24:41 +0200
Subject: [PATCH] Trying to fix return types of tabulation by going back to
 returning single-element matrices.

---
 hog/quadrature/tabulation.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/hog/quadrature/tabulation.py b/hog/quadrature/tabulation.py
index f623e46..cae9b41 100644
--- a/hog/quadrature/tabulation.py
+++ b/hog/quadrature/tabulation.py
@@ -56,17 +56,14 @@ class Tabulation:
         self.tables: Dict[str, Table] = {}
 
     def register_factor(
-        self, factor_name: str, factor: sp.Matrix | sp.Expr | int | float
-    ) -> sp.Matrix | int | float:
+        self, factor_name: str, factor: sp.Matrix | int | float
+    ) -> sp.Matrix:
         """Register a factor of the weak form that can be tabulated. Returns
         symbols replacing the expression for the factor. The symbols are returned
         in the same form as the factor was given. E.g. in case of a blended full
         Stokes operator we might encounter J_F^-1 grad phi being a matrix."""
 
-        if isinstance(factor, (int, float)):
-            return factor
-
-        if isinstance(factor, sp.Expr):
+        if not isinstance(factor, sp.MatrixBase):
             factor = sp.Matrix([factor])
 
         if all(f.is_constant() for f in factor):
@@ -80,9 +77,6 @@ class Tabulation:
                 table = self.tables.setdefault(table_name, Table(table_name))
                 replacement_symbols[r, c] = table.insert(factor[r, c])
 
-        if replacement_symbols.shape == (1, 1):
-            replacement_symbols = replacement_symbols[0]
-
         return replacement_symbols
 
     def construct_tables(
-- 
GitLab