diff --git a/hog/forms.py b/hog/forms.py
index 847d0aeb41a1789ba693d6a19600a2870f633b10..482dcf86dcd2785f99d081265cd52b68a57fd7c8 100644
--- a/hog/forms.py
+++ b/hog/forms.py
@@ -2124,72 +2124,76 @@ where
                     "mu", geometry, symbolizer, blending=blending
                 )
 
-        ux, dof_symbols = fem_function_on_element(
-            coefficient_function_space,
+        phi_eval_symbols_u = tabulation.register_phi_evals(
+            velocity_function_space.shape(geometry)
+        )
+        ux, dof_symbols_ux = fem_function_on_element(
+            velocity_function_space,
             geometry,
             symbolizer,
             domain="reference",
             function_id="ux",
-            basis_eval=phi_eval_symbols,
+            basis_eval=phi_eval_symbols_u,
         )
 
         grad_ux, _ = fem_function_gradient_on_element(
-            coefficient_function_space,
+            velocity_function_space,
             geometry,
             symbolizer,
             domain="reference",
             function_id="grad_ux",
-            dof_symbols=dof_symbols,
+            dof_symbols=dof_symbols_ux,
         )
 
-        uy, dof_symbols = fem_function_on_element(
-            coefficient_function_space,
+        uy, dof_symbols_uy = fem_function_on_element(
+            velocity_function_space,
             geometry,
             symbolizer,
             domain="reference",
             function_id="uy",
-            basis_eval=phi_eval_symbols,
+            basis_eval=phi_eval_symbols_u,
         )
 
         grad_uy, _ = fem_function_gradient_on_element(
-            coefficient_function_space,
+            velocity_function_space,
             geometry,
             symbolizer,
             domain="reference",
             function_id="grad_uy",
-            dof_symbols=dof_symbols,
+            dof_symbols=dof_symbols_uy,
         )
 
         grad_ux_affine = jac_affine_inv.T * grad_ux
         grad_uy_affine = jac_affine_inv.T * grad_uy
 
-        grad_u_affine = grad_ux_affine.row_join(grad_uy_affine)
-
-        u = sp.Matrix([[ux], [uy]])
-
-        if geometry.dimensions > 2:
-            uz, dof_symbols = fem_function_on_element(
-                coefficient_function_space,
-                geometry,
-                symbolizer,
-                domain="reference",
-                function_id="uz",
-                basis_eval=phi_eval_symbols,
-            )
+        # if geometry.dimensions > 2:
+        uz, dof_symbols_uz = fem_function_on_element(
+            velocity_function_space,
+            geometry,
+            symbolizer,
+            domain="reference",
+            function_id="uz",
+            basis_eval=phi_eval_symbols_u,
+        )
 
-            grad_uz, _ = fem_function_gradient_on_element(
-                coefficient_function_space,
-                geometry,
-                symbolizer,
-                domain="reference",
-                function_id="grad_uz",
-                dof_symbols=dof_symbols,
-            )
-            grad_uz_affine = jac_affine_inv.T * grad_uz
+        grad_uz, _ = fem_function_gradient_on_element(
+            velocity_function_space,
+            geometry,
+            symbolizer,
+            domain="reference",
+            function_id="grad_uz",
+            dof_symbols=dof_symbols_uz,
+        )
 
-            grad_u_affine = grad_u_affine.row_join(grad_uz_affine)
+        grad_uz_affine = jac_affine_inv.T * grad_uz
+        
+        grad_u_affine = grad_ux_affine.row_join(grad_uy_affine)
 
+        if geometry.dimensions == 2:
+            u = sp.Matrix([[ux], [uy]])
+        elif geometry.dimensions == 3:
             u = sp.Matrix([[ux], [uy], [uz]])
+            grad_u_affine = grad_u_affine.row_join(grad_uz_affine)
 
         mat = create_empty_element_matrix(trial, test, geometry)
         it = element_matrix_iterator(trial, test, geometry)
@@ -2250,7 +2254,7 @@ where
                             double_contraction(2 * sym_grad_phi, sym_grad_psi)
                             - sp.Rational(2, 3) * divdiv
                         )
-                        * dot(jac_affine_inv_T_grad_psi, u)
+                        * dot(u, jac_affine_inv_T_grad_psi)
                         * jac_affine_det,
                     )
                 )[
@@ -2258,8 +2262,12 @@ where
                 ]
                 
                 form = (
-                    mu
-                    * contract_2_jac_affine_inv_sym_grad_phi_jac_affine_inv_sym_grad_psi__min_2third_divdiv_det_symbol
+                    mu * (
+                        double_contraction(2 * sym_grad_phi, sym_grad_psi)
+                        - sp.Rational(2, 3) * divdiv
+                    )
+                    * dot(u, jac_affine_inv_T_grad_psi)
+                    * jac_affine_det
                 )
 
             mat[data.row, data.col] = form