diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 815738611dd89a8d8b30ebdbb2ca177f4c7f144a..4b1794e5d9c16ea5278ee4198ceb3d58b11edb6b 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -305,7 +305,11 @@ class DynamicFunction(sp.Function):
     """
 
     def __new__(cls, typed_function_symbol, return_dtype, *args):
-        return sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args)
+        obj = sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args)
+        if hasattr(return_dtype, 'shape'):
+            obj.shape = return_dtype.shape
+
+        return obj
 
     @property
     def function_dtype(self):
@@ -318,3 +322,9 @@ class DynamicFunction(sp.Function):
     @property
     def name(self):
         return self.args[0].name
+
+    def __str__(self):
+        return f'{self.name}({", ".join(str(a) for a in self.args[2:])})'
+
+    def __repr__(self):
+        return self.__str__()
diff --git a/tests/test_dynamic_function.py b/tests/test_dynamic_function.py
index 0423116585c4801ccf13a483a24c72580961d5be..4e266d1675bd43b3a36318426e6e527e6ce0d477 100644
--- a/tests/test_dynamic_function.py
+++ b/tests/test_dynamic_function.py
@@ -33,3 +33,54 @@ def test_dynamic_function():
     ast = pystencils.create_kernel(assignments)
     pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
     pystencils.show_code(ast, custom_backend=DebugFrameworkPrinter())
+
+
+def test_dynamic_matrix():
+    x, y = pystencils.fields('x, y:  float32[3d]')
+    from pystencils.data_types import TypedMatrixSymbol
+
+    a = sp.symbols('a')
+
+    A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')
+
+    my_fun_call = DynamicFunction(TypedSymbol('my_fun',
+                                              'std::function<Vector3<double>(double)>'), A.dtype, a)
+
+    assignments = pystencils.AssignmentCollection({
+        A:  my_fun_call,
+        y.center: A[0] + A[1] + A[2]
+    })
+
+    ast = pystencils.create_kernel(assignments)
+    pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
+
+
+def test_dynamic_matrix_location_dependent():
+    x, y = pystencils.fields('x, y:  float32[3d]')
+    from pystencils.data_types import TypedMatrixSymbol
+
+    A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')
+
+    my_fun_call = DynamicFunction(TypedSymbol('my_fun',
+                                              'std: : function < Vector3 < double > (int, int, int) >'),
+                                  A.dtype,
+                                  *pystencils.x_vector(3))
+
+    assignments = pystencils.AssignmentCollection({
+        A:  my_fun_call,
+        y.center: A[0] + A[1] + A[2]
+    })
+
+    ast = pystencils.create_kernel(assignments)
+    pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
+
+    my_fun_call = DynamicFunction(TypedSymbol('my_fun',
+                                              TemplateType('Functor_T')), A.dtype, *pystencils.x_vector(3))
+
+    assignments = pystencils.AssignmentCollection({
+        A:  my_fun_call,
+        y.center: A[0] + A[1] + A[2]
+    })
+
+    ast = pystencils.create_kernel(assignments)
+    pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())