From f77a5e493791c7f4c67be0fe44a96914b24c199b Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Fri, 26 Jul 2024 17:46:23 +0200
Subject: [PATCH] introduce NewTypes for Trial- and TestSpace such that mypy
 can check they are not confused; fix trial/test mix-ups

---
 generate_all_hyteg_forms.py                   |  25 ++--
 generate_all_operators.py                     | 112 +++++++++---------
 hog/fem_helpers.py                            |   6 +-
 hog/forms.py                                  | 103 ++++++++--------
 hog/forms_boundary.py                         |  22 ++--
 hog/forms_facets.py                           |  76 ++++++------
 hog/forms_facets_vectorial.py                 |  26 ++--
 hog/forms_vectorial.py                        |  36 +++---
 hog/function_space.py                         |  16 ++-
 hog/hyteg_form_template.py                    |  49 +++-----
 hog/integrand.py                              |   7 +-
 hog/manifold_forms.py                         |  38 +++---
 hog/operator_generation/indexing.py           |   3 -
 hog/operator_generation/kernel_types.py       |  40 +++----
 .../operator_generation/test_boundary_loop.py |  31 ++---
 .../operator_generation/test_indexing.py      |   4 +-
 .../operator_generation/test_opgen_smoke.py   |  26 ++--
 hog_tests/test_diffusion.py                   |  12 +-
 hog_tests/test_function_spaces.py             |   8 +-
 hog_tests/test_pspg.py                        |  13 +-
 hog_tests/test_quadrature.py                  |  17 ++-
 21 files changed, 345 insertions(+), 325 deletions(-)

diff --git a/generate_all_hyteg_forms.py b/generate_all_hyteg_forms.py
index 7a08fe4..d96ca5d 100644
--- a/generate_all_hyteg_forms.py
+++ b/generate_all_hyteg_forms.py
@@ -31,7 +31,12 @@ from hog.element_geometry import (
     TetrahedronElement,
     ElementGeometry,
 )
-from hog.function_space import FunctionSpace, LagrangianFunctionSpace, N1E1Space
+from hog.function_space import (
+    LagrangianFunctionSpace,
+    N1E1Space,
+    TrialSpace,
+    TestSpace,
+)
 from hog.forms import (
     mass,
     diffusion,
@@ -737,8 +742,8 @@ def form_func(
     name: str,
     row: int,
     col: int,
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     quad: Quadrature,
     symbolizer: Symbolizer,
@@ -1069,17 +1074,19 @@ def main():
     for form_info in filtered_form_infos:
         logger.info(f"{form_info}")
 
-        trial: FunctionSpace
+        trial: TrialSpace
         if form_info.trial_family == "N1E1":
-            trial = N1E1Space(symbolizer)
+            trial = TrialSpace(N1E1Space(symbolizer))
         else:
-            trial = LagrangianFunctionSpace(form_info.trial_degree, symbolizer)
+            trial = TrialSpace(
+                LagrangianFunctionSpace(form_info.trial_degree, symbolizer)
+            )
 
-        test: FunctionSpace
+        test: TestSpace
         if form_info.test_family == "N1E1":
-            test = N1E1Space(symbolizer)
+            test = TestSpace(N1E1Space(symbolizer))
         else:
-            test = LagrangianFunctionSpace(form_info.test_degree, symbolizer)
+            test = TestSpace(LagrangianFunctionSpace(form_info.test_degree, symbolizer))
 
         form_classes = []
 
diff --git a/generate_all_operators.py b/generate_all_operators.py
index 1467ac5..005b4d3 100644
--- a/generate_all_operators.py
+++ b/generate_all_operators.py
@@ -59,10 +59,11 @@ from hog.forms_boundary import (
 )
 from hog.forms_vectorial import curl_curl, curl_curl_plus_mass, mass_n1e1
 from hog.function_space import (
-    FunctionSpace,
     LagrangianFunctionSpace,
     N1E1Space,
     TensorialVectorFunctionSpace,
+    TrialSpace,
+    TestSpace,
 )
 from hog.logger import get_logger, TimedLogger
 from hog.operator_generation.kernel_types import (
@@ -490,13 +491,13 @@ def all_opts_both_cses() -> List[Tuple[Set[Opts], LoopStrategy, str]]:
 class OperatorInfo:
     mapping: str
     name: str
-    trial_space: FunctionSpace
-    test_space: FunctionSpace
+    trial_space: TrialSpace
+    test_space: TestSpace
     form: (
         Callable[
             [
-                FunctionSpace,
-                FunctionSpace,
+                TrialSpace,
+                TestSpace,
                 ElementGeometry,
                 Symbolizer,
                 GeometryMap,
@@ -509,8 +510,8 @@ class OperatorInfo:
     form_boundary: (
         Callable[
             [
-                FunctionSpace,
-                FunctionSpace,
+                TrialSpace,
+                TestSpace,
                 ElementGeometry,
                 ElementGeometry,
                 Symbolizer,
@@ -588,45 +589,45 @@ def all_operators(
 
     # fmt: off
     # TODO switch to manual specification of opts for now/developement, later use default factory
-    ops.append(OperatorInfo(mapping="N1E1", name="CurlCurl", trial_space=N1E1, test_space=N1E1, form=curl_curl,
+    ops.append(OperatorInfo("N1E1", "CurlCurl", TrialSpace(N1E1), TestSpace(N1E1), form=curl_curl,
                             type_descriptor=type_descriptor, geometries=three_d, opts=opts, blending=blending))
-    ops.append(OperatorInfo(mapping="N1E1", name="Mass", trial_space=N1E1, test_space=N1E1, form=mass_n1e1,
+    ops.append(OperatorInfo("N1E1", "Mass", TrialSpace(N1E1), TestSpace(N1E1), form=mass_n1e1,
                             type_descriptor=type_descriptor, geometries=three_d, opts=opts, blending=blending))
-    ops.append(OperatorInfo(mapping="N1E1", name="CurlCurlPlusMass", trial_space=N1E1, test_space=N1E1,
+    ops.append(OperatorInfo("N1E1", "CurlCurlPlusMass", TrialSpace(N1E1), TestSpace(N1E1),
                             form=partial(curl_curl_plus_mass, alpha_fem_space=P1, beta_fem_space=P1),
                             type_descriptor=type_descriptor, geometries=three_d, opts=opts, blending=blending))
-    ops.append(OperatorInfo(mapping="P1", name="Diffusion", trial_space=P1, test_space=P1, form=diffusion,
+    ops.append(OperatorInfo("P1", "Diffusion", TrialSpace(P1), TestSpace(P1), form=diffusion,
                             type_descriptor=type_descriptor, geometries=list(geometries), opts=opts, blending=blending))
-    ops.append(OperatorInfo(mapping="P1", name="DivKGrad", trial_space=P1, test_space=P1,
+    ops.append(OperatorInfo("P1", "DivKGrad", TrialSpace(P1), TestSpace(P1),
                             form=partial(div_k_grad, coefficient_function_space=P1),
                             type_descriptor=type_descriptor, geometries=list(geometries), opts=opts, blending=blending))
 
-    ops.append(OperatorInfo(mapping="P2", name="Diffusion", trial_space=P2, test_space=P2, form=diffusion,
+    ops.append(OperatorInfo("P2", "Diffusion", TrialSpace(P2), TestSpace(P2), form=diffusion,
                             type_descriptor=type_descriptor, geometries=list(geometries), opts=opts, blending=blending))
-    ops.append(OperatorInfo(mapping="P2", name="DivKGrad", trial_space=P2, test_space=P2,
+    ops.append(OperatorInfo("P2", "DivKGrad", TrialSpace(P2), TestSpace(P2),
                             form=partial(div_k_grad, coefficient_function_space=P2),
                             type_descriptor=type_descriptor, geometries=list(geometries), opts=opts, blending=blending))
 
-    ops.append(OperatorInfo(mapping="P2", name="ShearHeating", trial_space=P2, test_space=P2,
+    ops.append(OperatorInfo("P2", "ShearHeating", TrialSpace(P2), TestSpace(P2),
                             form=partial(shear_heating, viscosity_function_space=P2, velocity_function_space=P2),
                             type_descriptor=type_descriptor, geometries=list(geometries), opts=opts, blending=blending))
 
-    ops.append(OperatorInfo(mapping="P1", name="NonlinearDiffusion", trial_space=P1, test_space=P1,
+    ops.append(OperatorInfo("P1", "NonlinearDiffusion", TrialSpace(P1), TestSpace(P1),
                             form=partial(nonlinear_diffusion, coefficient_function_space=P1),
                             type_descriptor=type_descriptor, geometries=list(geometries), opts=opts, blending=blending))
-    ops.append(OperatorInfo(mapping="P1", name="NonlinearDiffusionNewtonGalerkin", trial_space=P1,
-                            test_space=P1, form=partial(nonlinear_diffusion_newton_galerkin,
+    ops.append(OperatorInfo("P1", "NonlinearDiffusionNewtonGalerkin", TrialSpace(P1),
+                            TestSpace(P1), form=partial(nonlinear_diffusion_newton_galerkin,
                             coefficient_function_space=P1, onlyNewtonGalerkinPartOfForm=False),
                             type_descriptor=type_descriptor, geometries=list(geometries), opts=opts, blending=blending))
 
-    ops.append(OperatorInfo(mapping="P1Vector", name="Diffusion", trial_space=P1Vector, test_space=P1Vector,
+    ops.append(OperatorInfo("P1Vector", "Diffusion", TrialSpace(P1Vector), TestSpace(P1Vector),
                             form=diffusion, type_descriptor=type_descriptor, geometries=list(geometries), opts=opts, blending=blending))
 
-    ops.append(OperatorInfo(mapping="P2", name="SUPGDiffusion", trial_space=P2, test_space=P2, 
+    ops.append(OperatorInfo("P2", "SUPGDiffusion", TrialSpace(P2), TestSpace(P2), 
                             form=partial(supg_diffusion, velocity_function_space=P2, diffusivityXdelta_function_space=P2), 
                             type_descriptor=type_descriptor, geometries=list(geometries), opts=opts, blending=blending))
 
-    ops.append(OperatorInfo(mapping="P2", name="BoundaryMass", trial_space=P2, test_space=P2, form=None,
+    ops.append(OperatorInfo("P2", "BoundaryMass", TrialSpace(P2), TestSpace(P2), form=None,
                             form_boundary=mass_boundary, type_descriptor=type_descriptor, geometries=list(geometries),
                             opts=opts, blending=blending))
 
@@ -639,10 +640,10 @@ def all_operators(
     )
     ops.append(
         OperatorInfo(
-            mapping=f"P2Vector",
-            name=f"Epsilon",
-            trial_space=P2Vector,
-            test_space=P2Vector,
+            f"P2Vector",
+            f"Epsilon",
+            TrialSpace(P2Vector),
+            TestSpace(P2Vector),
             form=p2vec_epsilon,
             type_descriptor=type_descriptor,
             geometries=list(geometries),
@@ -657,10 +658,10 @@ def all_operators(
 
     ops.append(
         OperatorInfo(
-            mapping=f"P2Vector",
-            name=f"EpsilonFreeslip",
-            trial_space=P2Vector,
-            test_space=P2Vector,
+            f"P2Vector",
+            f"EpsilonFreeslip",
+            TrialSpace(P2Vector),
+            TestSpace(P2Vector),
             form=p2vec_epsilon,
             form_boundary=p2vec_freeslip_momentum_weak_boundary,
             type_descriptor=type_descriptor,
@@ -672,10 +673,10 @@ def all_operators(
 
     ops.append(
         OperatorInfo(
-            mapping=f"P2VectorToP1",
-            name=f"DivergenceFreeslip",
-            trial_space=P2Vector,
-            test_space=P1,
+            f"P2VectorToP1",
+            f"DivergenceFreeslip",
+            TrialSpace(P2Vector),
+            TestSpace(P1),
             form=divergence,
             form_boundary=freeslip_divergence_weak_boundary,
             type_descriptor=type_descriptor,
@@ -687,10 +688,10 @@ def all_operators(
 
     ops.append(
         OperatorInfo(
-            mapping=f"P1ToP2Vector",
-            name=f"GradientFreeslip",
-            trial_space=P1,
-            test_space=P2Vector,
+            f"P1ToP2Vector",
+            f"GradientFreeslip",
+            TrialSpace(P1),
+            TestSpace(P2Vector),
             form=gradient,
             form_boundary=freeslip_gradient_weak_boundary,
             type_descriptor=type_descriptor,
@@ -706,10 +707,10 @@ def all_operators(
     )
     ops.append(
         OperatorInfo(
-            mapping=f"P2VectorToP1",
-            name=f"GradRhoByRhoDotU",
-            trial_space=P1,
-            test_space=P2Vector,
+            f"P2VectorToP1",
+            f"GradRhoByRhoDotU",
+            TrialSpace(P1),
+            TestSpace(P2Vector),
             form=p2vec_grad_rho,
             type_descriptor=type_descriptor,
             geometries=list(geometries),
@@ -724,13 +725,13 @@ def all_operators(
             div_geometries = three_d
         else:
             div_geometries = list(geometries)
-        ops.append(OperatorInfo(mapping=f"P2ToP1", name=f"Div_{c}",
-                                trial_space=TensorialVectorFunctionSpace(P2, single_component=c), test_space=P1,
+        ops.append(OperatorInfo(f"P2ToP1", f"Div_{c}",
+                                TrialSpace(TensorialVectorFunctionSpace(P2, single_component=c)), TestSpace(P1),
                                 form=partial(divergence, component_index=c),
                                 type_descriptor=type_descriptor, opts=opts, geometries=div_geometries,
                                 blending=blending))
-        ops.append(OperatorInfo(mapping=f"P1ToP2", name=f"DivT_{c}", trial_space=P1,
-                                test_space=TensorialVectorFunctionSpace(P2, single_component=c),
+        ops.append(OperatorInfo(f"P1ToP2", f"DivT_{c}", TrialSpace(P1),
+                                TestSpace(TensorialVectorFunctionSpace(P2, single_component=c)),
                                 form=partial(gradient, component_index=c),
                                 type_descriptor=type_descriptor, opts=opts, geometries=div_geometries,
                                 blending=blending))
@@ -755,14 +756,14 @@ def all_operators(
             )
             # fmt: off
             ops.append(
-                OperatorInfo(mapping=f"P2", name=f"Epsilon_{r}_{c}",
-                             trial_space=TensorialVectorFunctionSpace(P2, single_component=c),
-                             test_space=TensorialVectorFunctionSpace(P2, single_component=r), form=p2_epsilon,
+                OperatorInfo(f"P2", f"Epsilon_{r}_{c}",
+                             TrialSpace(TensorialVectorFunctionSpace(P2, single_component=c)),
+                             TestSpace(TensorialVectorFunctionSpace(P2, single_component=r)), form=p2_epsilon,
                              type_descriptor=type_descriptor, geometries=list(geometries), opts=opts,
                              blending=blending))
-            ops.append(OperatorInfo(mapping=f"P2", name=f"FullStokes_{r}_{c}",
-                                    trial_space=TensorialVectorFunctionSpace(P2, single_component=c),
-                                    test_space=TensorialVectorFunctionSpace(P2, single_component=r),
+            ops.append(OperatorInfo(f"P2", f"FullStokes_{r}_{c}",
+                                    TrialSpace(TensorialVectorFunctionSpace(P2, single_component=c)),
+                                    TestSpace(TensorialVectorFunctionSpace(P2, single_component=r)),
                                     form=p2_full_stokes, type_descriptor=type_descriptor, geometries=list(geometries),
                                     opts=opts, blending=blending))
             # fmt: on
@@ -784,12 +785,12 @@ def all_operators(
         )
         # fmt: off
         ops.append(
-            OperatorInfo(mapping=f"P2", name=f"Epsilon_{r}_{c}", trial_space=TensorialVectorFunctionSpace(P2, single_component=c),
-                         test_space=TensorialVectorFunctionSpace(P2, single_component=r), form=p2_epsilon,
+            OperatorInfo(f"P2", f"Epsilon_{r}_{c}", TrialSpace(TensorialVectorFunctionSpace(P2, single_component=c)),
+                         TestSpace(TensorialVectorFunctionSpace(P2, single_component=r)), form=p2_epsilon,
                          type_descriptor=type_descriptor, geometries=three_d, opts=opts, blending=blending))
         ops.append(
-            OperatorInfo(mapping=f"P2", name=f"FullStokes_{r}_{c}", trial_space=TensorialVectorFunctionSpace(P2, single_component=c),
-                         test_space=TensorialVectorFunctionSpace(P2, single_component=r), form=p2_full_stokes,
+            OperatorInfo(f"P2", f"FullStokes_{r}_{c}", TrialSpace(TensorialVectorFunctionSpace(P2, single_component=c)),
+                         TestSpace(TensorialVectorFunctionSpace(P2, single_component=r)), form=p2_full_stokes,
                          type_descriptor=type_descriptor, geometries=three_d, opts=opts, blending=blending))
         # fmt: on
 
@@ -851,7 +852,6 @@ def generate_elementwise_op(
             )
 
         if op_info.form_boundary is not None:
-
             boundary_geometry: ElementGeometry
             if geometry == TriangleElement():
                 boundary_geometry = LineElement(space_dimension=2)
diff --git a/hog/fem_helpers.py b/hog/fem_helpers.py
index c50bed0..be301a8 100644
--- a/hog/fem_helpers.py
+++ b/hog/fem_helpers.py
@@ -32,7 +32,7 @@ from hog.element_geometry import (
     LineElement,
 )
 from hog.exception import HOGException
-from hog.function_space import FunctionSpace
+from hog.function_space import FunctionSpace, TrialSpace, TestSpace
 from hog.math_helpers import inv, det
 from hog.multi_assignment import MultiAssignment
 from hog.symbolizer import Symbolizer
@@ -52,7 +52,7 @@ from hog.dof_symbol import DoFSymbol
 
 
 def create_empty_element_matrix(
-    trial: FunctionSpace, test: FunctionSpace, geometry: ElementGeometry
+    trial: TrialSpace, test: TestSpace, geometry: ElementGeometry
 ) -> sp.Matrix:
     """
     Returns a sympy matrix of the required size corresponding to the trial and test spaces, initialized with zeros.
@@ -75,7 +75,7 @@ class ElementMatrixData:
 
 
 def element_matrix_iterator(
-    trial: FunctionSpace, test: FunctionSpace, geometry: ElementGeometry
+    trial: TrialSpace, test: TestSpace, geometry: ElementGeometry
 ) -> Iterator[ElementMatrixData]:
     """Call this to create a generator to conveniently fill the element matrix."""
     for row, (psi, grad_psi, hessian_psi) in enumerate(
diff --git a/hog/forms.py b/hog/forms.py
index 839da83..4b0d0ef 100644
--- a/hog/forms.py
+++ b/hog/forms.py
@@ -33,7 +33,7 @@ from hog.fem_helpers import (
     fem_function_on_element,
     fem_function_gradient_on_element,
 )
-from hog.function_space import FunctionSpace, N1E1Space
+from hog.function_space import FunctionSpace, N1E1Space, TrialSpace, TestSpace
 from hog.math_helpers import dot, inv, abs, det, double_contraction
 from hog.quadrature import Quadrature, Tabulation
 from hog.symbolizer import Symbolizer
@@ -43,8 +43,8 @@ from hog.integrand import process_integrand, Form
 
 
 def diffusion(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -84,14 +84,14 @@ Weak formulation
         geometry,
         symbolizer,
         blending=blending,
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring=docstring,
     )
 
 
 def mass(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -118,14 +118,14 @@ Weak formulation
         geometry,
         symbolizer,
         blending=blending,
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring=docstring,
     )
 
 
 def div_k_grad(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -164,14 +164,14 @@ Weak formulation
         symbolizer,
         blending=blending,
         fe_coefficients={"k": coefficient_function_space},
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring=docstring,
     )
 
 
 def nonlinear_diffusion(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     coefficient_function_space: FunctionSpace,
@@ -197,7 +197,7 @@ Note: :math:`a(c) = 1/8 + u^2` is currently hard-coded and the form is intended
             "The nonlinear-diffusion form does currently not support blending."
         )
 
-    if trial != test:
+    if trial != test:  # type: ignore[comparison-overlap]
         raise HOGException(
             "Trial space must be equal to test space to assemble non-linear diffusion matrix."
         )
@@ -228,15 +228,15 @@ Note: :math:`a(c) = 1/8 + u^2` is currently hard-coded and the form is intended
         geometry,
         symbolizer,
         blending=blending,
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring=docstring,
         fe_coefficients={"u": coefficient_function_space},
     )
 
 
 def nonlinear_diffusion_newton_galerkin(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     coefficient_function_space: FunctionSpace,
@@ -257,7 +257,7 @@ Weak formulation
 
 Note: :math:`a(k) = 1/8 + k^2` is currently hard-coded and the form is intended for :math:`k = u`.
 """
-    if trial != test:
+    if trial != test:  # type: ignore[comparison-overlap]
         raise HOGException(
             "Trial space must be equal to test space to assemble diffusion matrix."
         )
@@ -312,8 +312,8 @@ Note: :math:`a(k) = 1/8 + k^2` is currently hard-coded and the form is intended
 
 
 def epsilon(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -322,7 +322,6 @@ def epsilon(
     variable_viscosity: bool = True,
     coefficient_function_space: Optional[FunctionSpace] = None,
 ) -> Form:
-
     docstring = f"""
 "Epsilon" operator.
 
@@ -361,15 +360,15 @@ where
         geometry,
         symbolizer,
         blending=blending,
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring=docstring,
         fe_coefficients={"mu": coefficient_function_space},
     )
 
 
 def k_mass(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -398,14 +397,14 @@ Weak formulation
         geometry,
         symbolizer,
         blending=blending,
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring=docstring,
     )
 
 
 def pspg(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     quad: Quadrature,
     symbolizer: Symbolizer,
@@ -456,14 +455,14 @@ for details.
         geometry,
         symbolizer,
         blending=blending,
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring=docstring,
     )
 
 
 def linear_form(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     quad: Quadrature,
     symbolizer: Symbolizer,
@@ -477,7 +476,7 @@ def linear_form(
     where psi a test function and k = k(x) a scalar, external function.
     """
 
-    if trial != test:
+    if trial != test:  # type: ignore[comparison-overlap]
         raise HOGException(
             "Trial space must be equal to test space to assemble linear form (jep this is weird, but linear forms are implemented as diagonal matrices)."
         )
@@ -534,14 +533,13 @@ def linear_form(
 
 
 def divergence(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
     component_index: int = 0,
 ) -> Form:
-
     docstring = f"""
 Divergence.
 
@@ -571,8 +569,8 @@ Weak formulation
 
 
 def gradient(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -607,8 +605,8 @@ def gradient(
 
 
 def full_stokes(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -656,15 +654,15 @@ where
         geometry,
         symbolizer,
         blending=blending,
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring=docstring,
         fe_coefficients={"mu": coefficient_function_space},
     )
 
 
 def shear_heating(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -764,14 +762,14 @@ The resulting matrix must be multiplied with a vector of ones to be used as the
             "wy": velocity_function_space,
             "wz": velocity_function_space,
         },
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring=docstring,
     )
 
 
 def divdiv(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     component_trial: int,
     component_test: int,
     geometry: ElementGeometry,
@@ -809,14 +807,14 @@ Weak formulation
         geometry,
         symbolizer,
         blending=blending,
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring=docstring,
     )
 
 
 def supg_diffusion(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     velocity_function_space: FunctionSpace,
@@ -927,8 +925,8 @@ Weak formulation
 
 
 def grad_rho_by_rho_dot_u(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -1023,17 +1021,16 @@ Weak formulation
                     )
                 mat[data.row, data.col] = form
 
-    return Form(mat, tabulation, symmetric=trial == test, docstring=docstring)
+    return Form(mat, tabulation, symmetric=trial == test, docstring=docstring)  # type: ignore[comparison-overlap]
 
 
 def zero_form(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
 ) -> Form:
-
     from hog.recipes.integrands.volume.zero import integrand
 
     return process_integrand(
@@ -1043,6 +1040,6 @@ def zero_form(
         geometry,
         symbolizer,
         blending=blending,
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring="",
     )
diff --git a/hog/forms_boundary.py b/hog/forms_boundary.py
index 5d81376..49a810d 100644
--- a/hog/forms_boundary.py
+++ b/hog/forms_boundary.py
@@ -17,15 +17,15 @@
 
 from typing import Optional
 from hog.element_geometry import ElementGeometry
-from hog.function_space import FunctionSpace
+from hog.function_space import FunctionSpace, TrialSpace, TestSpace
 from hog.symbolizer import Symbolizer
 from hog.blending import GeometryMap, IdentityMap
 from hog.integrand import process_integrand, Form
 
 
 def mass_boundary(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     volume_geometry: ElementGeometry,
     boundary_geometry: ElementGeometry,
     symbolizer: Symbolizer,
@@ -54,14 +54,14 @@ Weak formulation
         symbolizer,
         blending=blending,
         boundary_geometry=boundary_geometry,
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         docstring=docstring,
     )
 
 
 def freeslip_momentum_weak_boundary(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     volume_geometry: ElementGeometry,
     boundary_geometry: ElementGeometry,
     symbolizer: Symbolizer,
@@ -124,15 +124,15 @@ Geometry map: {blending}
         symbolizer,
         blending=blending,
         boundary_geometry=boundary_geometry,
-        is_symmetric=trial == test,
+        is_symmetric=trial == test,  # type: ignore[comparison-overlap]
         fe_coefficients={"mu": function_space_mu},
         docstring=docstring,
     )
 
 
 def freeslip_divergence_weak_boundary(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     volume_geometry: ElementGeometry,
     boundary_geometry: ElementGeometry,
     symbolizer: Symbolizer,
@@ -176,8 +176,8 @@ Geometry map: {blending}
 
 
 def freeslip_gradient_weak_boundary(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     volume_geometry: ElementGeometry,
     boundary_geometry: ElementGeometry,
     symbolizer: Symbolizer,
diff --git a/hog/forms_facets.py b/hog/forms_facets.py
index 5229ceb..daff842 100644
--- a/hog/forms_facets.py
+++ b/hog/forms_facets.py
@@ -33,7 +33,7 @@ from hog.external_functions import (
     ScalarVariableCoefficient2D,
     ScalarVariableCoefficient3D,
 )
-from hog.function_space import FunctionSpace
+from hog.function_space import FunctionSpace, TrialSpace, TestSpace
 from hog.math_helpers import (
     dot,
     inv,
@@ -51,7 +51,9 @@ sigma_0 = 3
 
 def _affine_element_vertices(
     volume_element_geometry: ElementGeometry, symbolizer: Symbolizer
-) -> Tuple[List[sp.Matrix], List[sp.Matrix], List[sp.Matrix], sp.Matrix, sp.Matrix, sp.Matrix]:
+) -> Tuple[
+    List[sp.Matrix], List[sp.Matrix], List[sp.Matrix], sp.Matrix, sp.Matrix, sp.Matrix
+]:
     """Helper function that returns the symbols of the affine points of two neighboring elements.
 
     Returns a tuple of
@@ -145,8 +147,7 @@ def trafo_ref_interface_to_ref_element(
         affine_points_E = affine_points_E2
         affine_point_E_opposite = affine_point_E2_opposite
     else:
-        raise HOGException(
-            "Invalid element type (should be 'inner' or 'outer')")
+        raise HOGException("Invalid element type (should be 'inner' or 'outer')")
 
     # First we compute the transformation from the interface reference space to the affine interface.
     trafo_ref_interface_to_affine_interface_E = trafo_ref_to_affine(
@@ -168,8 +169,8 @@ def trafo_ref_interface_to_ref_element(
 
 def stokes_p0_stabilization(
     interface_type: str,
-    test_element: FunctionSpace,
-    trial_element: FunctionSpace,
+    test_element: TestSpace,
+    trial_element: TrialSpace,
     volume_element_geometry: ElementGeometry,
     facet_quad: Quadrature,
     symbolizer: Symbolizer,
@@ -247,7 +248,6 @@ def stokes_p0_stabilization(
             level=logging.DEBUG,
         ):
             for data in it:
-
                 # TODO: fix this by introducing extra symbols for the shape functions
                 phi = data.trial_shape
                 psi = data.test_shape
@@ -274,28 +274,26 @@ def stokes_p0_stabilization(
                 gamma = 0.1
 
                 if interface_type == "inner":
-                    form = ((gamma * volume_interface) *
-                            phi * psi) * volume_interface
+                    form = ((gamma * volume_interface) * phi * psi) * volume_interface
 
                 elif interface_type == "outer":
-                    form = (-(gamma * volume_interface) *
-                            phi * psi) * volume_interface
+                    form = (-(gamma * volume_interface) * phi * psi) * volume_interface
 
-                mat[data.row, data.col] = facet_quad.integrate(form, symbolizer)[0].subs(
-                    reference_symbols[volume_element_geometry.dimensions - 1], 0
-                )
+                mat[data.row, data.col] = facet_quad.integrate(form, symbolizer)[
+                    0
+                ].subs(reference_symbols[volume_element_geometry.dimensions - 1], 0)
 
     return mat
 
 
 def diffusion_sip_facet(
     interface_type: str,
-    test_element_1: FunctionSpace,
-    trial_element_2: FunctionSpace,
+    test_element_1: TestSpace,
+    trial_element_2: TrialSpace,
     volume_element_geometry: ElementGeometry,
     facet_quad: Quadrature,
     symbolizer: Symbolizer,
-    blending: GeometryMap = IdentityMap()
+    blending: GeometryMap = IdentityMap(),
 ) -> sp.Matrix:
     r"""
     Interface integrals for the symmetric interior penalty formulation for the (constant-coeff.) Laplacian.
@@ -427,15 +425,15 @@ def diffusion_sip_facet(
             level=logging.DEBUG,
         ):
             for data in it:
-
                 # TODO: fix this by introducing extra symbols for the shape functions
                 phi = data.trial_shape
                 psi = data.test_shape
                 grad_phi = data.trial_shape_grad
                 grad_psi = data.test_shape_grad
 
-                shape_symbols = ["xi_shape_0", "xi_shape_1",
-                                 "xi_shape_2"][:volume_element_geometry.dimensions]
+                shape_symbols = ["xi_shape_0", "xi_shape_1", "xi_shape_2"][
+                    : volume_element_geometry.dimensions
+                ]
                 phi = phi.subs(zip(reference_symbols, shape_symbols))
                 psi = psi.subs(zip(reference_symbols, shape_symbols))
                 grad_phi = grad_phi.subs(zip(reference_symbols, shape_symbols))
@@ -468,37 +466,35 @@ def diffusion_sip_facet(
                 if interface_type == "inner":
                     form = (
                         -0.5
-                        * dot(grad_psi*jac_affine_inv_E1, outward_normal)[0, 0]
+                        * dot(grad_psi * jac_affine_inv_E1, outward_normal)[0, 0]
                         * phi
                         - 0.5
-                        * dot(grad_phi*jac_affine_inv_E1, outward_normal)[0, 0]
+                        * dot(grad_phi * jac_affine_inv_E1, outward_normal)[0, 0]
                         * psi
-                        + (sigma_0 / volume_interface ** beta_0) * phi * psi
+                        + (sigma_0 / volume_interface**beta_0) * phi * psi
                     ) * volume_interface
 
                 elif interface_type == "outer":
                     form = (
                         0.5
-                        * dot(grad_psi*jac_affine_inv_E1, outward_normal)[0, 0]
+                        * dot(grad_psi * jac_affine_inv_E1, outward_normal)[0, 0]
                         * phi
                         - 0.5
-                        * dot(grad_phi*jac_affine_inv_E2, outward_normal)[0, 0]
+                        * dot(grad_phi * jac_affine_inv_E2, outward_normal)[0, 0]
                         * psi
-                        - (sigma_0 / volume_interface ** beta_0) * phi * psi
+                        - (sigma_0 / volume_interface**beta_0) * phi * psi
                     ) * volume_interface
 
                 elif interface_type == "dirichlet":
                     form = (
-                        -dot(grad_psi*jac_affine_inv_E1,
-                             outward_normal)[0, 0] * phi
-                        - dot(grad_phi*jac_affine_inv_E1, outward_normal)[0, 0]
-                        * psi
-                        + (4 * sigma_0 / volume_interface ** beta_0) * phi * psi
+                        -dot(grad_psi * jac_affine_inv_E1, outward_normal)[0, 0] * phi
+                        - dot(grad_phi * jac_affine_inv_E1, outward_normal)[0, 0] * psi
+                        + (4 * sigma_0 / volume_interface**beta_0) * phi * psi
                     ) * volume_interface
 
-                mat[data.row, data.col] = facet_quad.integrate(form, symbolizer)[0].subs(
-                    reference_symbols[volume_element_geometry.dimensions - 1], 0
-                )
+                mat[data.row, data.col] = facet_quad.integrate(form, symbolizer)[
+                    0
+                ].subs(reference_symbols[volume_element_geometry.dimensions - 1], 0)
 
     return mat
 
@@ -612,19 +608,16 @@ def diffusion_sip_rhs_dirichlet(
             coeff_class = ScalarVariableCoefficient2D
         elif isinstance(volume_element_geometry, TetrahedronElement):
             coeff_class = ScalarVariableCoefficient3D
-        g = coeff_class(sp.Symbol("g"), 0, *
-                        trafo_ref_interface_to_affine_interface)
+        g = coeff_class(sp.Symbol("g"), 0, *trafo_ref_interface_to_affine_interface)
 
         with TimedLogger(
             f"integrating {mat.shape[0] * mat.shape[1]} expressions",
             level=logging.DEBUG,
         ):
             for i in range(function_space.num_dofs(volume_element_geometry)):
-
                 # TODO: fix this by introducing extra symbols for the shape functions
                 phi = function_space.shape(volume_element_geometry)[i]
-                grad_phi = function_space.grad_shape(
-                    volume_element_geometry)[i]
+                grad_phi = function_space.grad_shape(volume_element_geometry)[i]
 
                 shape_symbols = ["xi_shape_0", "xi_shape_1"]
                 phi = phi.subs(zip(reference_symbols, shape_symbols))
@@ -643,9 +636,8 @@ def diffusion_sip_rhs_dirichlet(
                 form = (
                     1
                     * (
-                        -dot(jac_affine_inv_E1.T *
-                             grad_phi, outward_normal)[0, 0]
-                        + (4 * sigma_0 / volume_interface ** beta_0) * phi
+                        -dot(jac_affine_inv_E1.T * grad_phi, outward_normal)[0, 0]
+                        + (4 * sigma_0 / volume_interface**beta_0) * phi
                     )
                     * g
                     * volume_interface
diff --git a/hog/forms_facets_vectorial.py b/hog/forms_facets_vectorial.py
index db52401..5779447 100644
--- a/hog/forms_facets_vectorial.py
+++ b/hog/forms_facets_vectorial.py
@@ -31,7 +31,7 @@ from hog.external_functions import (
     ScalarVariableCoefficient2D,
     ScalarVariableCoefficient3D,
 )
-from hog.function_space import FunctionSpace
+from hog.function_space import FunctionSpace, TrialSpace, TestSpace
 from hog.math_helpers import (
     dot,
     inv,
@@ -52,8 +52,8 @@ from hog.function_space import EnrichedGalerkinFunctionSpace
 
 def diffusion_sip_facet_vectorial(
     interface_type: str,
-    test_element_1: FunctionSpace,
-    trial_element_2: FunctionSpace,
+    test_element_1: TestSpace,
+    trial_element_2: TrialSpace,
     volume_element_geometry: ElementGeometry,
     facet_quad: Quadrature,
     symbolizer: Symbolizer,
@@ -227,8 +227,8 @@ def diffusion_sip_facet_vectorial(
 
 def diffusion_iip_facet_vectorial(
     interface_type: str,
-    test_element_1: FunctionSpace,
-    trial_element_2: FunctionSpace,
+    test_element_1: TestSpace,
+    trial_element_2: TrialSpace,
     volume_element_geometry: ElementGeometry,
     facet_quad: Quadrature,
     symbolizer: Symbolizer,
@@ -382,8 +382,8 @@ def diffusion_iip_facet_vectorial(
 
 def divergence_facet_vectorial(
     interface_type: str,
-    test_element_1: FunctionSpace,
-    trial_element_2: FunctionSpace,
+    test_element_1: TestSpace,
+    trial_element_2: TrialSpace,
     transpose: bool,
     volume_element_geometry: ElementGeometry,
     facet_quad: Quadrature,
@@ -540,8 +540,8 @@ def symm_grad(grad, jac):
 
 def epsilon_sip_facet_vectorial(
     interface_type: str,
-    test_element_1: FunctionSpace,
-    trial_element_2: FunctionSpace,
+    test_element_1: TestSpace,
+    trial_element_2: TrialSpace,
     volume_element_geometry: ElementGeometry,
     facet_quad: Quadrature,
     symbolizer: Symbolizer,
@@ -760,8 +760,8 @@ def epsilon_sip_facet_vectorial(
 
 def epsilon_nip_facet_vectorial(
     interface_type: str,
-    test_element_1: FunctionSpace,
-    trial_element_2: FunctionSpace,
+    test_element_1: TestSpace,
+    trial_element_2: TrialSpace,
     volume_element_geometry: ElementGeometry,
     facet_quad: Quadrature,
     symbolizer: Symbolizer,
@@ -1092,8 +1092,8 @@ def epsilon_sip_rhs_dirichlet_vectorial(
 
 def diffusion_nip_facet_vectorial(
     interface_type: str,
-    test_element_1: FunctionSpace,
-    trial_element_2: FunctionSpace,
+    test_element_1: TestSpace,
+    trial_element_2: TrialSpace,
     volume_element_geometry: ElementGeometry,
     facet_quad: Quadrature,
     symbolizer: Symbolizer,
diff --git a/hog/forms_vectorial.py b/hog/forms_vectorial.py
index a49ce70..4f032e5 100644
--- a/hog/forms_vectorial.py
+++ b/hog/forms_vectorial.py
@@ -31,7 +31,13 @@ from hog.fem_helpers import (
     scalar_space_dependent_coefficient,
 )
 from hog.integrand import Form
-from hog.function_space import FunctionSpace, EnrichedGalerkinFunctionSpace, N1E1Space
+from hog.function_space import (
+    FunctionSpace,
+    EnrichedGalerkinFunctionSpace,
+    N1E1Space,
+    TrialSpace,
+    TestSpace,
+)
 from hog.math_helpers import inv, abs, det, double_contraction, dot, curl
 from hog.quadrature import Quadrature, Tabulation
 from hog.symbolizer import Symbolizer
@@ -40,8 +46,8 @@ from hog.sympy_extensions import fast_subs
 
 
 def diffusion_vectorial(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     quad: Quadrature,
     symbolizer: Symbolizer,
@@ -99,8 +105,8 @@ def diffusion_vectorial(
 
 
 def mass_vectorial(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     quad: Quadrature,
     symbolizer: Symbolizer,
@@ -149,8 +155,8 @@ def mass_vectorial(
 
 
 def mass_n1e1(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -195,8 +201,8 @@ def mass_n1e1(
 
 
 def divergence_vectorial(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     transpose: bool,
     geometry: ElementGeometry,
     quad: Quadrature,
@@ -264,8 +270,8 @@ def divergence_vectorial(
 
 
 def curl_curl(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -320,8 +326,8 @@ def curl_curl(
 
 
 def curl_curl_plus_mass(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -471,8 +477,8 @@ Strong formulation
 
 
 def linear_form_vectorial(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     quad: Quadrature,
     symbolizer: Symbolizer,
diff --git a/hog/function_space.py b/hog/function_space.py
index 0f4029c..1818c0f 100644
--- a/hog/function_space.py
+++ b/hog/function_space.py
@@ -14,8 +14,9 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
+from abc import ABC, abstractmethod
 from pyclbr import Function
-from typing import Any, List, Optional, Protocol, Union
+from typing import Any, List, NewType, Optional, Union
 import sympy as sp
 
 from hog.element_geometry import (
@@ -28,34 +29,40 @@ from hog.math_helpers import grad, hessian
 from hog.symbolizer import Symbolizer
 
 
-class FunctionSpace(Protocol):
+class FunctionSpace(ABC):
     """Representation of a finite element function space."""
 
     @property
+    @abstractmethod
     def family(self) -> str:
         """The common name of this FEM space."""
         ...
 
     @property
+    @abstractmethod
     def is_vectorial(self) -> bool:
         """Whether shape functions are scalar or vector valued."""
         ...
 
     @property
+    @abstractmethod
     def is_continuous(self) -> bool:
         """Whether functions in this space are continuous across elements."""
         ...
 
     @property
+    @abstractmethod
     def degree(self) -> int:
         """The polynomial degree of the shape functions."""
         ...
 
     @property
+    @abstractmethod
     def symbolizer(self) -> Symbolizer:
         """The symbolizer used to construct this object."""
         ...
 
+    @abstractmethod
     def shape(
         self,
         geometry: ElementGeometry,
@@ -125,6 +132,10 @@ class FunctionSpace(Protocol):
         return len(self.shape(geometry))
 
 
+TrialSpace = NewType("TrialSpace", FunctionSpace)
+TestSpace = NewType("TestSpace", FunctionSpace)
+
+
 class LagrangianFunctionSpace(FunctionSpace):
     """Representation of a finite element function spaces.
 
@@ -425,7 +436,6 @@ class TensorialVectorFunctionSpace(FunctionSpace):
         domain: str = "reference",
         dof_map: Optional[List[int]] = None,
     ) -> List[sp.MatrixBase]:
-
         dim = geometry.dimensions
 
         shape_functions = self._component_function_space.shape(
diff --git a/hog/hyteg_form_template.py b/hog/hyteg_form_template.py
index 6a2bccd..f0d4308 100644
--- a/hog/hyteg_form_template.py
+++ b/hog/hyteg_form_template.py
@@ -21,7 +21,7 @@ from hog.ast import Assignment, CodeBlock
 from hog.exception import HOGException
 from hog.quadrature import Quadrature
 from hog.symbolizer import Symbolizer
-from hog.function_space import (FunctionSpace, N1E1Space)
+from hog.function_space import N1E1Space, TrialSpace, TestSpace
 from hog.element_geometry import ElementGeometry
 from hog.code_generation import code_block_from_element_matrix
 from hog.multi_assignment import Member
@@ -91,8 +91,9 @@ class HyTeGIntegrator:
         )
         return info
 
-    def _setup_methods(self) -> Tuple[str, str, List[str], List[str], List[str], List[str], List[Member]]:
-
+    def _setup_methods(
+        self,
+    ) -> Tuple[str, str, List[str], List[str], List[str], List[str], List[Member]]:
         rows, cols = self.element_matrix.shape
 
         # read from input array of computational vertices
@@ -110,9 +111,7 @@ class HyTeGIntegrator:
         integrate_impl = {}
 
         for integrate_matrix_element in self.integrate_matrix_elements:
-
             if integrate_matrix_element[0] == "all":
-
                 method_name = "integrateAll"
                 element_matrix_sliced = self.element_matrix
                 cpp_override = True
@@ -121,15 +120,13 @@ class HyTeGIntegrator:
                 output_assignments = []
                 for row in range(rows):
                     for col in range(cols):
-                        lhs = self.symbolizer.output_element_matrix_access(
-                            row, col)
+                        lhs = self.symbolizer.output_element_matrix_access(row, col)
                         rhs = self.symbolizer.element_matrix_entry(row, col)
                         output_assignments.append(
                             Assignment(lhs, rhs, is_declaration=False)
                         )
 
             elif integrate_matrix_element[0] == "row":
-
                 integrate_row = integrate_matrix_element[1]
 
                 method_name = f"integrateRow{integrate_row}"
@@ -144,8 +141,7 @@ class HyTeGIntegrator:
                 output_assignments = []
                 for col in range(cols):
                     lhs = self.symbolizer.output_element_matrix_access(0, col)
-                    rhs = self.symbolizer.element_matrix_entry(
-                        integrate_row, col)
+                    rhs = self.symbolizer.element_matrix_entry(integrate_row, col)
                     output_assignments.append(
                         Assignment(lhs, rhs, is_declaration=False)
                     )
@@ -164,8 +160,7 @@ class HyTeGIntegrator:
             if self.not_implemented:
                 code_block_code = ""
             else:
-                code_block_code = "\n      ".join(
-                    code_block.to_code().splitlines())
+                code_block_code = "\n      ".join(code_block.to_code().splitlines())
 
             hyteg_matrix_type = f"Matrix< real_t, {output_rows}, {output_cols} >"
 
@@ -183,8 +178,7 @@ class HyTeGIntegrator:
                 return f"   void {prefix}{method_name}( const std::array< Point3D, {self.geometry.num_vertices} >& {'' if without_argnames else 'coords'}, {hyteg_matrix_type}& {'' if without_argnames else 'elMat'} ) const{override_str}"
 
             integrate_decl[integrate_matrix_element] = (
-                "   " +
-                "\n   ".join(self._docstring(code_block).splitlines()) + "\n"
+                "   " + "\n   ".join(self._docstring(code_block).splitlines()) + "\n"
             )
             integrate_decl[
                 integrate_matrix_element
@@ -213,8 +207,7 @@ class HyTeGIntegrator:
                 fd_code = "   " + "\n   ".join(fd.declaration().splitlines())
                 helper_methods_decl.append(fd_code)
                 fd_code = "   " + "\n   ".join(
-                    fd.implementation(
-                        name_prefix=self.class_name + "::").splitlines()
+                    fd.implementation(name_prefix=self.class_name + "::").splitlines()
                 )
                 helper_methods_impl.append(fd_code)
 
@@ -230,12 +223,11 @@ class HyTeGIntegrator:
 
 
 class HyTeGFormClass:
-
     def __init__(
         self,
         name: str,
-        trial: FunctionSpace,
-        test: FunctionSpace,
+        trial: TrialSpace,
+        test: TestSpace,
         integrators: List[HyTeGIntegrator],
         description: str = "",
     ):
@@ -266,22 +258,20 @@ class HyTeGFormClass:
         for f in self.integrators:
             members += f.members
 
-        members = sorted(set(members), key = lambda m: m.name_constructor)
+        members = sorted(set(members), key=lambda m: m.name_constructor)
 
         if not members:
             return ""
 
         default_constructor = f'{self.name}() {{ WALBERLA_ABORT("Not implemented."); }}'
 
-        ctr_prms = ", ".join(
-            [f"{m.dtype} {m.name_constructor}" for m in members])
+        ctr_prms = ", ".join([f"{m.dtype} {m.name_constructor}" for m in members])
 
         init_list = "\n   , ".join(
             [f"{m.name_member}({m.name_constructor})" for m in members]
         )
 
-        member_decl = "\n   ".join(
-            [f"{m.dtype} {m.name_member};" for m in members])
+        member_decl = "\n   ".join([f"{m.dtype} {m.name_member};" for m in members])
 
         constructor_string = f""" public:
 
@@ -298,7 +288,6 @@ class HyTeGFormClass:
         return constructor_string
 
     def to_code(self, header: bool = True) -> str:
-
         file_string = []
 
         if isinstance(self.trial, N1E1Space):
@@ -346,7 +335,6 @@ class HyTeGFormClass:
 
 
 class HyTeGForm:
-
     NAMESPACE_OPEN = "namespace hyteg {\nnamespace forms {"
 
     NAMESPACE_CLOSE = "} // namespace forms\n} // namespace hyteg"
@@ -354,8 +342,8 @@ class HyTeGForm:
     def __init__(
         self,
         name: str,
-        trial: FunctionSpace,
-        test: FunctionSpace,
+        trial: TrialSpace,
+        test: TestSpace,
         formClasses: List[HyTeGFormClass],
         description: str = "",
     ):
@@ -366,7 +354,6 @@ class HyTeGForm:
         self.description = description
 
     def to_code(self, header: bool = True) -> str:
-
         file_string = []
 
         if isinstance(self.trial, N1E1Space):
@@ -374,7 +361,9 @@ class HyTeGForm:
         elif self.trial.degree == self.test.degree:
             super_class = f"form_hyteg_base/P{self.trial.degree}FormHyTeG"
         else:
-            super_class = f"form_hyteg_base/P{self.trial.degree}ToP{self.test.degree}FormHyTeG"
+            super_class = (
+                f"form_hyteg_base/P{self.trial.degree}ToP{self.test.degree}FormHyTeG"
+            )
 
         if header:
             includes = "\n".join(
diff --git a/hog/integrand.py b/hog/integrand.py
index ce25acc..b6175a3 100644
--- a/hog/integrand.py
+++ b/hog/integrand.py
@@ -52,7 +52,7 @@ from dataclasses import dataclass, asdict, fields, field
 import sympy as sp
 
 from hog.exception import HOGException
-from hog.function_space import FunctionSpace
+from hog.function_space import FunctionSpace, TrialSpace, TestSpace
 from hog.element_geometry import ElementGeometry
 from hog.quadrature import Quadrature, Tabulation
 from hog.symbolizer import Symbolizer
@@ -216,8 +216,8 @@ class IntegrandSymbols:
 
 def process_integrand(
     integrand: Callable[..., Any],
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     volume_geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -352,7 +352,6 @@ def process_integrand(
         s.hessian_b = symbolizer.hessian_blending_map(volume_geometry.dimensions)
 
     if boundary_geometry is not None:
-
         if boundary_geometry.dimensions != boundary_geometry.space_dimension - 1:
             raise HOGException(
                 "Since you are integrating over a boundary, the boundary element's space dimension should be larger "
diff --git a/hog/manifold_forms.py b/hog/manifold_forms.py
index d6ccfc2..f3e87d0 100644
--- a/hog/manifold_forms.py
+++ b/hog/manifold_forms.py
@@ -27,7 +27,7 @@ from hog.fem_helpers import (
     create_empty_element_matrix,
     element_matrix_iterator,
 )
-from hog.function_space import FunctionSpace
+from hog.function_space import TrialSpace, TestSpace
 from hog.math_helpers import dot, inv, abs, det, double_contraction, e_vec
 from hog.quadrature import Tabulation
 from hog.symbolizer import Symbolizer
@@ -39,8 +39,8 @@ from hog.manifold_helpers import face_projection, embedded_normal
 
 
 def laplace_beltrami(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -59,7 +59,7 @@ Weak formulation
     ∫ ∇u · G^(-1) · ∇v · (det(G))^0.5
 """
 
-    if trial != test:
+    if trial != test:  # type: ignore[comparison-overlap]
         raise HOGException(
             "Trial space must be equal to test space to assemble laplace beltrami matrix."
         )
@@ -115,8 +115,8 @@ Weak formulation
 
 
 def manifold_mass(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -135,7 +135,7 @@ Weak formulation
     ∫ uv · (det(G))^0.5
 """
 
-    if trial != test:
+    if trial != test:  # type: ignore[comparison-overlap]
         raise HOGException(
             "Trial space must be equal to test space to assemble laplace beltrami matrix."
         )
@@ -186,8 +186,8 @@ Weak formulation
 
 
 def manifold_vector_mass(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -261,8 +261,8 @@ Weak formulation
 
 
 def manifold_normal_penalty(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -336,8 +336,8 @@ Weak formulation
 
 
 def manifold_divergence(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -409,8 +409,8 @@ Weak formulation
 
 
 def manifold_vector_divergence(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -500,8 +500,8 @@ Weak formulation
 
 
 def manifold_epsilon(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
@@ -596,8 +596,8 @@ Weak formulation
 
 
 def vector_laplace_beltrami(
-    trial: FunctionSpace,
-    test: FunctionSpace,
+    trial: TrialSpace,
+    test: TestSpace,
     geometry: ElementGeometry,
     symbolizer: Symbolizer,
     blending: GeometryMap = IdentityMap(),
diff --git a/hog/operator_generation/indexing.py b/hog/operator_generation/indexing.py
index b29e566..e33704e 100644
--- a/hog/operator_generation/indexing.py
+++ b/hog/operator_generation/indexing.py
@@ -22,7 +22,6 @@ import sympy as sp
 
 from hog.element_geometry import ElementGeometry, TriangleElement, TetrahedronElement
 from hog.exception import HOGException
-from hog.function_space import FunctionSpace, N1E1Space
 from hog.symbolizer import Symbolizer
 
 from pystencils.integer_functions import int_div
@@ -35,7 +34,6 @@ import sympy as sp
 
 from hog.element_geometry import ElementGeometry, TriangleElement, TetrahedronElement
 from hog.exception import HOGException
-from hog.function_space import FunctionSpace
 from hog.symbolizer import Symbolizer
 
 from pystencils.integer_functions import int_div
@@ -173,7 +171,6 @@ def num_microcells_per_cell(level: int) -> int:
 
 
 def linear_macro_cell_size(width: int) -> int:
-
     if USE_SYMPY_INT_DIV:
         return sympy_int_div((width + 2) * (width + 1) * width, 6)
     else:
diff --git a/hog/operator_generation/kernel_types.py b/hog/operator_generation/kernel_types.py
index 474a9b4..6cd3f5a 100644
--- a/hog/operator_generation/kernel_types.py
+++ b/hog/operator_generation/kernel_types.py
@@ -37,7 +37,7 @@ from hog.cpp_printing import (
 )
 from hog.element_geometry import ElementGeometry
 from hog.exception import HOGException
-from hog.function_space import FunctionSpace
+from hog.function_space import FunctionSpace, TrialSpace, TestSpace
 from hog.operator_generation.function_space_impls import FunctionSpaceImpl
 from hog.operator_generation.indexing import FaceType, CellType
 from hog.operator_generation.pystencils_extensions import create_generic_fields
@@ -213,18 +213,13 @@ class AssembleDiagonal(KernelType):
 class Assemble(KernelType):
     def __init__(
         self,
-        src_space: FunctionSpace,
-        dst_space: FunctionSpace,
+        src: FunctionSpaceImpl,
+        dst: FunctionSpaceImpl,
     ):
         self.result_prefix = "elMat_"
-        idx_t = HOGType("idx_t", np.int64)
 
-        self.src: FunctionSpaceImpl = FunctionSpaceImpl.create_impl(
-            src_space, "src", idx_t
-        )
-        self.dst: FunctionSpaceImpl = FunctionSpaceImpl.create_impl(
-            dst_space, "dst", idx_t
-        )
+        self.src = src
+        self.dst = dst
 
     def kernel_operation(
         self,
@@ -350,19 +345,24 @@ class KernelWrapperType(ABC):
 
     @property
     @abstractmethod
-    def kernel_type(self) -> KernelType: ...
+    def kernel_type(self) -> KernelType:
+        ...
 
     @abstractmethod
-    def includes(self) -> Set[str]: ...
+    def includes(self) -> Set[str]:
+        ...
 
     @abstractmethod
-    def base_classes(self) -> List[str]: ...
+    def base_classes(self) -> List[str]:
+        ...
 
     @abstractmethod
-    def wrapper_methods(self) -> List[CppMethod]: ...
+    def wrapper_methods(self) -> List[CppMethod]:
+        ...
 
     @abstractmethod
-    def member_variables(self) -> List[CppMemberVariable]: ...
+    def member_variables(self) -> List[CppMemberVariable]:
+        ...
 
     def substitute(self, subs: Mapping[str, object]) -> None:
         self._template = Template(self._template.safe_substitute(subs))
@@ -371,8 +371,8 @@ class KernelWrapperType(ABC):
 class ApplyWrapper(KernelWrapperType):
     def __init__(
         self,
-        src_space: FunctionSpace,
-        dst_space: FunctionSpace,
+        src_space: TrialSpace,
+        dst_space: TestSpace,
         type_descriptor: HOGType,
         dims: List[int] = [2, 3],
     ):
@@ -683,8 +683,8 @@ class AssembleDiagonalWrapper(KernelWrapperType):
 class AssembleWrapper(KernelWrapperType):
     def __init__(
         self,
-        src_space: FunctionSpace,
-        dst_space: FunctionSpace,
+        src_space: TrialSpace,
+        dst_space: TestSpace,
         type_descriptor: HOGType,
         dims: List[int] = [2, 3],
     ):
@@ -761,7 +761,7 @@ class AssembleWrapper(KernelWrapperType):
 
     @property
     def kernel_type(self) -> KernelType:
-        return Assemble(self.src.fe_space, self.dst.fe_space)
+        return Assemble(self.src, self.dst)
 
     def includes(self) -> Set[str]:
         return (
diff --git a/hog_tests/operator_generation/test_boundary_loop.py b/hog_tests/operator_generation/test_boundary_loop.py
index 34a504e..41e733f 100644
--- a/hog_tests/operator_generation/test_boundary_loop.py
+++ b/hog_tests/operator_generation/test_boundary_loop.py
@@ -17,22 +17,26 @@ import logging
 
 from sympy.core.cache import clear_cache
 
-from hog.blending import AnnulusMap, IcosahedralShellMap
-from hog.element_geometry import LineElement, TriangleElement, TetrahedronElement
-from hog.function_space import LagrangianFunctionSpace
+from hog.blending import AnnulusMap, GeometryMap, IcosahedralShellMap
+from hog.element_geometry import (
+    ElementGeometry,
+    LineElement,
+    TriangleElement,
+    TetrahedronElement,
+)
+from hog.function_space import LagrangianFunctionSpace, TrialSpace, TestSpace
 from hog.operator_generation.operators import (
     HyTeGElementwiseOperator,
 )
 from hog.symbolizer import Symbolizer
 from hog.quadrature import Quadrature, select_quadrule
-from hog.operator_generation.kernel_types import ApplyWrapper
+from hog.operator_generation.kernel_types import ApplyWrapper, KernelWrapperType
 from hog.operator_generation.types import hyteg_type
 from hog.forms_boundary import mass_boundary
 from hog.logger import TimedLogger
 
 
 def test_boundary_loop():
-
     # TimedLogger.set_log_level(logging.DEBUG)
 
     dims = [2, 3]
@@ -43,15 +47,15 @@ def test_boundary_loop():
 
     name = f"P2MassBoundary"
 
-    trial = LagrangianFunctionSpace(2, symbolizer)
-    test = LagrangianFunctionSpace(2, symbolizer)
+    trial = TrialSpace(LagrangianFunctionSpace(2, symbolizer))
+    test = TestSpace(LagrangianFunctionSpace(2, symbolizer))
 
     type_descriptor = hyteg_type()
 
-    kernel_types = [
+    kernel_types: list[KernelWrapperType] = [
         ApplyWrapper(
-            test,
             trial,
+            test,
             type_descriptor=type_descriptor,
             dims=dims,
         )
@@ -65,15 +69,12 @@ def test_boundary_loop():
     )
 
     for dim in dims:
-
         if dim == 2:
-
-            volume_geometry = TriangleElement()
-            boundary_geometry = LineElement(space_dimension=2)
-            blending = AnnulusMap()
+            volume_geometry: ElementGeometry = TriangleElement()
+            boundary_geometry: ElementGeometry = LineElement(space_dimension=2)
+            blending: GeometryMap = AnnulusMap()
 
         else:
-
             volume_geometry = TetrahedronElement()
             boundary_geometry = TriangleElement(space_dimension=3)
             blending = IcosahedralShellMap()
diff --git a/hog_tests/operator_generation/test_indexing.py b/hog_tests/operator_generation/test_indexing.py
index 852463b..1387e75 100644
--- a/hog_tests/operator_generation/test_indexing.py
+++ b/hog_tests/operator_generation/test_indexing.py
@@ -263,13 +263,13 @@ def test_micro_volume_to_volume_indices():
         geometry: ElementGeometry,
         level: int,
         indexing_info: IndexingInfo,
-        n_dofs_per_primitive,
+        n_dofs_per_primitive: int,
         primitive_type: Union[FaceType, CellType],
         primitive_index: Tuple[int, int, int],
         target_array_index: int,
         intra_primitive_index: int = 0,
         memory_layout: VolumeDoFMemoryLayout = VolumeDoFMemoryLayout.AoS,
-    ):
+    ) -> None:
         indexing_info.level = level
         dof_indices = indexing.micro_element_to_volume_indices(
             primitive_type, primitive_index, n_dofs_per_primitive, memory_layout
diff --git a/hog_tests/operator_generation/test_opgen_smoke.py b/hog_tests/operator_generation/test_opgen_smoke.py
index c3b50a3..1b99af1 100644
--- a/hog_tests/operator_generation/test_opgen_smoke.py
+++ b/hog_tests/operator_generation/test_opgen_smoke.py
@@ -18,8 +18,13 @@ from sympy.core.cache import clear_cache
 
 from hog.operator_generation.loop_strategies import CUBES
 from hog.operator_generation.optimizer import Opts
-from hog.element_geometry import LineElement, TriangleElement, TetrahedronElement
-from hog.function_space import LagrangianFunctionSpace
+from hog.element_geometry import (
+    ElementGeometry,
+    LineElement,
+    TriangleElement,
+    TetrahedronElement,
+)
+from hog.function_space import LagrangianFunctionSpace, TrialSpace, TestSpace
 from hog.operator_generation.operators import HyTeGElementwiseOperator
 from hog.symbolizer import Symbolizer
 from hog.quadrature import Quadrature, select_quadrule
@@ -27,7 +32,7 @@ from hog.forms import div_k_grad
 from hog.forms_boundary import mass_boundary
 from hog.operator_generation.kernel_types import ApplyWrapper, AssembleWrapper
 from hog.operator_generation.types import hyteg_type
-from hog.blending import AnnulusMap, IcosahedralShellMap
+from hog.blending import AnnulusMap, GeometryMap, IcosahedralShellMap
 
 
 def test_opgen_smoke():
@@ -53,22 +58,22 @@ def test_opgen_smoke():
 
     dims = [2]
 
-    trial = LagrangianFunctionSpace(2, symbolizer)
-    test = LagrangianFunctionSpace(2, symbolizer)
+    trial = TrialSpace(LagrangianFunctionSpace(2, symbolizer))
+    test = TestSpace(LagrangianFunctionSpace(2, symbolizer))
     coeff = LagrangianFunctionSpace(2, symbolizer)
 
     type_descriptor = hyteg_type()
 
     kernel_types = [
         ApplyWrapper(
-            test,
             trial,
+            test,
             type_descriptor=type_descriptor,
             dims=dims,
         ),
         AssembleWrapper(
-            test,
             trial,
+            test,
             type_descriptor=type_descriptor,
             dims=dims,
         ),
@@ -85,11 +90,10 @@ def test_opgen_smoke():
     opts_boundary = {Opts.MOVECONSTANTS}
 
     for d in dims:
-
         if d == 2:
-            volume_geometry = TriangleElement()
-            boundary_geometry = LineElement(space_dimension=2)
-            blending_map = AnnulusMap()
+            volume_geometry: ElementGeometry = TriangleElement()
+            boundary_geometry: ElementGeometry = LineElement(space_dimension=2)
+            blending_map: GeometryMap = AnnulusMap()
         else:
             volume_geometry = TetrahedronElement()
             boundary_geometry = TriangleElement(space_dimension=3)
diff --git a/hog_tests/test_diffusion.py b/hog_tests/test_diffusion.py
index b92b534..dc8d321 100644
--- a/hog_tests/test_diffusion.py
+++ b/hog_tests/test_diffusion.py
@@ -20,7 +20,7 @@ import logging
 from hog.blending import IdentityMap, ExternalMap
 from hog.element_geometry import TriangleElement, TetrahedronElement
 from hog.forms import diffusion
-from hog.function_space import LagrangianFunctionSpace
+from hog.function_space import LagrangianFunctionSpace, TrialSpace, TestSpace
 from hog.hyteg_form_template import HyTeGForm, HyTeGFormClass, HyTeGIntegrator
 from hog.quadrature import Quadrature, select_quadrule
 from hog.symbolizer import Symbolizer
@@ -40,7 +40,7 @@ def test_diffusion_p1_affine():
     symbolizer = Symbolizer()
 
     geometries = [TriangleElement(), TetrahedronElement()]
-    schemes = {TriangleElement() : 2, TetrahedronElement() : 2 }
+    schemes = {TriangleElement(): 2, TetrahedronElement(): 2}
     blending = IdentityMap()
 
     class_name = f"P1DiffusionAffine"
@@ -48,8 +48,8 @@ def test_diffusion_p1_affine():
     form_codes = []
 
     for geometry in geometries:
-        trial = LagrangianFunctionSpace(1, symbolizer)
-        test = LagrangianFunctionSpace(1, symbolizer)
+        trial = TrialSpace(LagrangianFunctionSpace(1, symbolizer))
+        test = TestSpace(LagrangianFunctionSpace(1, symbolizer))
         quad = Quadrature(select_quadrule(schemes[geometry], geometry), geometry)
 
         mat = diffusion(
@@ -90,8 +90,8 @@ def test_diffusion_p2_blending_2D():
 
     form_codes = []
 
-    trial = LagrangianFunctionSpace(2, symbolizer)
-    test = LagrangianFunctionSpace(2, symbolizer)
+    trial = TrialSpace(LagrangianFunctionSpace(2, symbolizer))
+    test = TestSpace(LagrangianFunctionSpace(2, symbolizer))
     schemes = {TriangleElement(): 4, TetrahedronElement(): 4}
 
     quad = Quadrature(select_quadrule(schemes[geometry], geometry), geometry)
diff --git a/hog_tests/test_function_spaces.py b/hog_tests/test_function_spaces.py
index 1ffa604..00bdc1e 100644
--- a/hog_tests/test_function_spaces.py
+++ b/hog_tests/test_function_spaces.py
@@ -16,7 +16,11 @@
 
 import sympy as sp
 from hog.element_geometry import TriangleElement
-from hog.function_space import LagrangianFunctionSpace, TensorialVectorFunctionSpace
+from hog.function_space import (
+    FunctionSpace,
+    LagrangianFunctionSpace,
+    TensorialVectorFunctionSpace,
+)
 from hog.symbolizer import Symbolizer
 from hog.exception import HOGException
 
@@ -27,7 +31,7 @@ def test_function_spaces():
 
     print()
 
-    f = LagrangianFunctionSpace(1, symbolizer)
+    f: FunctionSpace = LagrangianFunctionSpace(1, symbolizer)
     f_shape = f.shape(geometry)
     f_grad_shape = f.grad_shape(geometry)
     print(f)
diff --git a/hog_tests/test_pspg.py b/hog_tests/test_pspg.py
index 08ffcc5..671532e 100644
--- a/hog_tests/test_pspg.py
+++ b/hog_tests/test_pspg.py
@@ -20,7 +20,7 @@ import logging
 from hog.blending import IdentityMap
 from hog.element_geometry import TriangleElement, TetrahedronElement
 from hog.forms import pspg
-from hog.function_space import LagrangianFunctionSpace
+from hog.function_space import LagrangianFunctionSpace, TrialSpace, TestSpace
 from hog.hyteg_form_template import HyTeGForm, HyTeGFormClass, HyTeGIntegrator
 from hog.quadrature import Quadrature
 from hog.symbolizer import Symbolizer
@@ -44,8 +44,8 @@ def test_pspg_p1_affine():
     form_codes = []
 
     for geometry in geometries:
-        trial = LagrangianFunctionSpace(1, symbolizer)
-        test = LagrangianFunctionSpace(1, symbolizer)
+        trial = TrialSpace(LagrangianFunctionSpace(1, symbolizer))
+        test = TestSpace(LagrangianFunctionSpace(1, symbolizer))
         quad = Quadrature("exact", geometry)
 
         form = pspg(
@@ -58,7 +58,12 @@ def test_pspg_p1_affine():
         )
         form_codes.append(
             HyTeGIntegrator(
-                class_name, form.integrand, geometry, quad, symbolizer, integrate_rows=[0]
+                class_name,
+                form.integrand,
+                geometry,
+                quad,
+                symbolizer,
+                integrate_rows=[0],
             )
         )
 
diff --git a/hog_tests/test_quadrature.py b/hog_tests/test_quadrature.py
index 7574265..42ff338 100644
--- a/hog_tests/test_quadrature.py
+++ b/hog_tests/test_quadrature.py
@@ -14,7 +14,12 @@
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
-from hog.element_geometry import LineElement, TriangleElement, TetrahedronElement
+from hog.element_geometry import (
+    ElementGeometry,
+    LineElement,
+    TriangleElement,
+    TetrahedronElement,
+)
 from hog.quadrature import Quadrature, select_quadrule
 from hog.exception import HOGException
 
@@ -22,10 +27,13 @@ from hog.exception import HOGException
 def test_smoke():
     """Just a brief test to see if the quadrature class does _something_."""
 
-    geometries = [TriangleElement(), TetrahedronElement()] # TODO fix quad for lines
+    geometries = [TriangleElement(), TetrahedronElement()]  # TODO fix quad for lines
 
     for geometry in geometries:
-        schemes = {TriangleElement(): "exact", TetrahedronElement(): "exact"}
+        schemes: dict[ElementGeometry, str | int] = {
+            TriangleElement(): "exact",
+            TetrahedronElement(): "exact",
+        }
 
         quad = Quadrature(select_quadrule(schemes[geometry], geometry), geometry)
         print("points", quad.points())
@@ -42,4 +50,5 @@ def test_smoke():
             print("points", quad.points())
             print("weights", quad.weights())
 
-test_smoke()
\ No newline at end of file
+
+test_smoke()
-- 
GitLab