From 4261344e3eef95f768afc5a5865f51e9b63764ad Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 11 Feb 2020 18:17:58 +0100
Subject: [PATCH] Add test_dynamic_function, make DynamicFunction work with
 MatrixSymbols

---
 .../framework_integration/astnodes.py         | 12 ++++-
 tests/test_dynamic_function.py                | 51 +++++++++++++++++++
 2 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 8157386..4b1794e 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -305,7 +305,11 @@ class DynamicFunction(sp.Function):
     """
 
     def __new__(cls, typed_function_symbol, return_dtype, *args):
-        return sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args)
+        obj = sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args)
+        if hasattr(return_dtype, 'shape'):
+            obj.shape = return_dtype.shape
+
+        return obj
 
     @property
     def function_dtype(self):
@@ -318,3 +322,9 @@ class DynamicFunction(sp.Function):
     @property
     def name(self):
         return self.args[0].name
+
+    def __str__(self):
+        return f'{self.name}({", ".join(str(a) for a in self.args[2:])})'
+
+    def __repr__(self):
+        return self.__str__()
diff --git a/tests/test_dynamic_function.py b/tests/test_dynamic_function.py
index 0423116..4e266d1 100644
--- a/tests/test_dynamic_function.py
+++ b/tests/test_dynamic_function.py
@@ -33,3 +33,54 @@ def test_dynamic_function():
     ast = pystencils.create_kernel(assignments)
     pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
     pystencils.show_code(ast, custom_backend=DebugFrameworkPrinter())
+
+
+def test_dynamic_matrix():
+    x, y = pystencils.fields('x, y:  float32[3d]')
+    from pystencils.data_types import TypedMatrixSymbol
+
+    a = sp.symbols('a')
+
+    A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')
+
+    my_fun_call = DynamicFunction(TypedSymbol('my_fun',
+                                              'std::function<Vector3<double>(double)>'), A.dtype, a)
+
+    assignments = pystencils.AssignmentCollection({
+        A:  my_fun_call,
+        y.center: A[0] + A[1] + A[2]
+    })
+
+    ast = pystencils.create_kernel(assignments)
+    pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
+
+
+def test_dynamic_matrix_location_dependent():
+    x, y = pystencils.fields('x, y:  float32[3d]')
+    from pystencils.data_types import TypedMatrixSymbol
+
+    A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')
+
+    my_fun_call = DynamicFunction(TypedSymbol('my_fun',
+                                              'std: : function < Vector3 < double > (int, int, int) >'),
+                                  A.dtype,
+                                  *pystencils.x_vector(3))
+
+    assignments = pystencils.AssignmentCollection({
+        A:  my_fun_call,
+        y.center: A[0] + A[1] + A[2]
+    })
+
+    ast = pystencils.create_kernel(assignments)
+    pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
+
+    my_fun_call = DynamicFunction(TypedSymbol('my_fun',
+                                              TemplateType('Functor_T')), A.dtype, *pystencils.x_vector(3))
+
+    assignments = pystencils.AssignmentCollection({
+        A:  my_fun_call,
+        y.center: A[0] + A[1] + A[2]
+    })
+
+    ast = pystencils.create_kernel(assignments)
+    pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
-- 
GitLab