diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 815738611dd89a8d8b30ebdbb2ca177f4c7f144a..4b1794e5d9c16ea5278ee4198ceb3d58b11edb6b 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -305,7 +305,11 @@ class DynamicFunction(sp.Function): """ def __new__(cls, typed_function_symbol, return_dtype, *args): - return sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args) + obj = sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args) + if hasattr(return_dtype, 'shape'): + obj.shape = return_dtype.shape + + return obj @property def function_dtype(self): @@ -318,3 +322,9 @@ class DynamicFunction(sp.Function): @property def name(self): return self.args[0].name + + def __str__(self): + return f'{self.name}({", ".join(str(a) for a in self.args[2:])})' + + def __repr__(self): + return self.__str__() diff --git a/tests/test_dynamic_function.py b/tests/test_dynamic_function.py index 0423116585c4801ccf13a483a24c72580961d5be..4e266d1675bd43b3a36318426e6e527e6ce0d477 100644 --- a/tests/test_dynamic_function.py +++ b/tests/test_dynamic_function.py @@ -33,3 +33,54 @@ def test_dynamic_function(): ast = pystencils.create_kernel(assignments) pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter()) pystencils.show_code(ast, custom_backend=DebugFrameworkPrinter()) + + +def test_dynamic_matrix(): + x, y = pystencils.fields('x, y: float32[3d]') + from pystencils.data_types import TypedMatrixSymbol + + a = sp.symbols('a') + + A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>') + + my_fun_call = DynamicFunction(TypedSymbol('my_fun', + 'std::function<Vector3<double>(double)>'), A.dtype, a) + + assignments = pystencils.AssignmentCollection({ + A: my_fun_call, + y.center: A[0] + A[1] + A[2] + }) + + ast = pystencils.create_kernel(assignments) + pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter()) + + +def test_dynamic_matrix_location_dependent(): + x, y = pystencils.fields('x, y: float32[3d]') + from pystencils.data_types import TypedMatrixSymbol + + A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>') + + my_fun_call = DynamicFunction(TypedSymbol('my_fun', + 'std: : function < Vector3 < double > (int, int, int) >'), + A.dtype, + *pystencils.x_vector(3)) + + assignments = pystencils.AssignmentCollection({ + A: my_fun_call, + y.center: A[0] + A[1] + A[2] + }) + + ast = pystencils.create_kernel(assignments) + pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter()) + + my_fun_call = DynamicFunction(TypedSymbol('my_fun', + TemplateType('Functor_T')), A.dtype, *pystencils.x_vector(3)) + + assignments = pystencils.AssignmentCollection({ + A: my_fun_call, + y.center: A[0] + A[1] + A[2] + }) + + ast = pystencils.create_kernel(assignments) + pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())