From 4261344e3eef95f768afc5a5865f51e9b63764ad Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Tue, 11 Feb 2020 18:17:58 +0100 Subject: [PATCH] Add test_dynamic_function, make DynamicFunction work with MatrixSymbols --- .../framework_integration/astnodes.py | 12 ++++- tests/test_dynamic_function.py | 51 +++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 8157386..4b1794e 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -305,7 +305,11 @@ class DynamicFunction(sp.Function): """ def __new__(cls, typed_function_symbol, return_dtype, *args): - return sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args) + obj = sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args) + if hasattr(return_dtype, 'shape'): + obj.shape = return_dtype.shape + + return obj @property def function_dtype(self): @@ -318,3 +322,9 @@ class DynamicFunction(sp.Function): @property def name(self): return self.args[0].name + + def __str__(self): + return f'{self.name}({", ".join(str(a) for a in self.args[2:])})' + + def __repr__(self): + return self.__str__() diff --git a/tests/test_dynamic_function.py b/tests/test_dynamic_function.py index 0423116..4e266d1 100644 --- a/tests/test_dynamic_function.py +++ b/tests/test_dynamic_function.py @@ -33,3 +33,54 @@ def test_dynamic_function(): ast = pystencils.create_kernel(assignments) pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter()) pystencils.show_code(ast, custom_backend=DebugFrameworkPrinter()) + + +def test_dynamic_matrix(): + x, y = pystencils.fields('x, y: float32[3d]') + from pystencils.data_types import TypedMatrixSymbol + + a = sp.symbols('a') + + A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>') + + my_fun_call = DynamicFunction(TypedSymbol('my_fun', + 'std::function<Vector3<double>(double)>'), A.dtype, a) + + assignments = pystencils.AssignmentCollection({ + A: my_fun_call, + y.center: A[0] + A[1] + A[2] + }) + + ast = pystencils.create_kernel(assignments) + pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter()) + + +def test_dynamic_matrix_location_dependent(): + x, y = pystencils.fields('x, y: float32[3d]') + from pystencils.data_types import TypedMatrixSymbol + + A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>') + + my_fun_call = DynamicFunction(TypedSymbol('my_fun', + 'std: : function < Vector3 < double > (int, int, int) >'), + A.dtype, + *pystencils.x_vector(3)) + + assignments = pystencils.AssignmentCollection({ + A: my_fun_call, + y.center: A[0] + A[1] + A[2] + }) + + ast = pystencils.create_kernel(assignments) + pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter()) + + my_fun_call = DynamicFunction(TypedSymbol('my_fun', + TemplateType('Functor_T')), A.dtype, *pystencils.x_vector(3)) + + assignments = pystencils.AssignmentCollection({ + A: my_fun_call, + y.center: A[0] + A[1] + A[2] + }) + + ast = pystencils.create_kernel(assignments) + pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter()) -- GitLab