Skip to content
Snippets Groups Projects
Commit 4261344e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add test_dynamic_function, make DynamicFunction work with MatrixSymbols

parent f65844f4
Branches
Tags
No related merge requests found
Pipeline #21825 failed with stage
in 5 minutes and 4 seconds
......@@ -305,7 +305,11 @@ class DynamicFunction(sp.Function):
"""
def __new__(cls, typed_function_symbol, return_dtype, *args):
return sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args)
obj = sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args)
if hasattr(return_dtype, 'shape'):
obj.shape = return_dtype.shape
return obj
@property
def function_dtype(self):
......@@ -318,3 +322,9 @@ class DynamicFunction(sp.Function):
@property
def name(self):
return self.args[0].name
def __str__(self):
return f'{self.name}({", ".join(str(a) for a in self.args[2:])})'
def __repr__(self):
return self.__str__()
......@@ -33,3 +33,54 @@ def test_dynamic_function():
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
pystencils.show_code(ast, custom_backend=DebugFrameworkPrinter())
def test_dynamic_matrix():
x, y = pystencils.fields('x, y: float32[3d]')
from pystencils.data_types import TypedMatrixSymbol
a = sp.symbols('a')
A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')
my_fun_call = DynamicFunction(TypedSymbol('my_fun',
'std::function<Vector3<double>(double)>'), A.dtype, a)
assignments = pystencils.AssignmentCollection({
A: my_fun_call,
y.center: A[0] + A[1] + A[2]
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
def test_dynamic_matrix_location_dependent():
x, y = pystencils.fields('x, y: float32[3d]')
from pystencils.data_types import TypedMatrixSymbol
A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')
my_fun_call = DynamicFunction(TypedSymbol('my_fun',
'std: : function < Vector3 < double > (int, int, int) >'),
A.dtype,
*pystencils.x_vector(3))
assignments = pystencils.AssignmentCollection({
A: my_fun_call,
y.center: A[0] + A[1] + A[2]
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
my_fun_call = DynamicFunction(TypedSymbol('my_fun',
TemplateType('Functor_T')), A.dtype, *pystencils.x_vector(3))
assignments = pystencils.AssignmentCollection({
A: my_fun_call,
y.center: A[0] + A[1] + A[2]
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment