-
Stephan Seitz authoredStephan Seitz authored
test_dynamic_function.py 3.08 KiB
import sympy as sp
import pystencils
from pystencils.data_types import TypedSymbol, create_type
from pystencils_autodiff.framework_integration.astnodes import DynamicFunction
from pystencils_autodiff.framework_integration.printer import (
DebugFrameworkPrinter, FrameworkIntegrationPrinter)
from pystencils_autodiff.framework_integration.types import TemplateType
def test_dynamic_function():
x, y = pystencils.fields('x, y: float32[3d]')
a = sp.symbols('a')
my_fun_call = DynamicFunction(TypedSymbol('my_fun',
'std::function<double(double)>'), create_type('double'), a)
assignments = pystencils.AssignmentCollection({
y.center: x.center + my_fun_call
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
template_fun_call = DynamicFunction(TypedSymbol('my_fun',
TemplateType('Functor_T')), create_type('double'), a, x.center)
assignments = pystencils.AssignmentCollection({
y.center: x.center + template_fun_call
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
pystencils.show_code(ast, custom_backend=DebugFrameworkPrinter())
def test_dynamic_matrix():
x, y = pystencils.fields('x, y: float32[3d]')
from pystencils.data_types import TypedMatrixSymbol
a = sp.symbols('a')
A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')
my_fun_call = DynamicFunction(TypedSymbol('my_fun',
'std::function<Vector3<double>(double)>'), A.dtype, a)
assignments = pystencils.AssignmentCollection({
A: my_fun_call,
y.center: A[0] + A[1] + A[2]
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
def test_dynamic_matrix_location_dependent():
x, y = pystencils.fields('x, y: float32[3d]')
from pystencils.data_types import TypedMatrixSymbol
A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')
my_fun_call = DynamicFunction(TypedSymbol('my_fun',
'std: : function < Vector3 < double > (int, int, int) >'),
A.dtype,
*pystencils.x_vector(3))
assignments = pystencils.AssignmentCollection({
A: my_fun_call,
y.center: A[0] + A[1] + A[2]
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
my_fun_call = DynamicFunction(TypedSymbol('my_fun',
TemplateType('Functor_T')), A.dtype, *pystencils.x_vector(3))
assignments = pystencils.AssignmentCollection({
A: my_fun_call,
y.center: A[0] + A[1] + A[2]
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())