diff --git a/hog/fem_helpers.py b/hog/fem_helpers.py index be301a8f6e287c26219deca9179e2c34501efcea..855bdd9ef3443249b0b20d6b4c3213e21b809a15 100644 --- a/hog/fem_helpers.py +++ b/hog/fem_helpers.py @@ -468,10 +468,13 @@ def fem_function_on_element( domain == "reference" ), "Tabulating the basis evaluation not implemented for affine domain." + rows = geometry.dimensions if function_space.is_vectorial else 1 + if domain == "reference": # On the reference domain, the reference coordinates symbols can be used directly, so no substitution # has to be performed for the shape functions. - s = sp.zeros(1, 1) + + s = sp.zeros(rows, 1) for dof, phi in zip( dofs, ( @@ -487,7 +490,7 @@ def fem_function_on_element( # On the affine / computational domain, the evaluation point is first mapped to reference space and then # the reference space coordinate symbols are substituted with the transformed point. eval_point_on_ref = trafo_affine_point_to_ref(geometry, symbolizer=symbolizer) - s = sp.zeros(1, 1) + s = sp.zeros(rows, 1) for dof, phi in zip( dofs, function_space.shape( @@ -523,7 +526,7 @@ def fem_function_gradient_on_element( dof_map: Optional[List[int]] = None, basis_eval: Union[str, List[sp.Expr]] = "default", dof_symbols: Optional[List[DoFSymbol]] = None, -) -> sp.Matrix: +) -> Tuple[sp.Matrix, List[DoFSymbol]]: """Returns an expression that is the gradient of the element-local polynomial, either in affine or reference coordinates. The expression is build using DoFSymbol instances so that the DoFs can be resolved later. @@ -553,7 +556,8 @@ def fem_function_gradient_on_element( if domain == "reference": # On the reference domain, the reference coordinates symbols can be used directly, so no substitution # has to be performed for the shape functions. - s = sp.zeros(geometry.dimensions, 1) + cols = geometry.dimensions if function_space.is_vectorial else 1 + s = sp.zeros(geometry.dimensions, cols) for dof, grad_phi in zip( dofs, ( @@ -571,4 +575,4 @@ def fem_function_gradient_on_element( raise HOGException( f"Invalid domain '{domain}': cannot evaluate local polynomial here." ) - return s, dofs + return sp.Matrix(s), dofs