Skip to content
Snippets Groups Projects
test_dynamic_function.py 3.08 KiB
import sympy as sp

import pystencils
from pystencils.data_types import TypedSymbol, create_type
from pystencils_autodiff.framework_integration.astnodes import DynamicFunction
from pystencils_autodiff.framework_integration.printer import (
    DebugFrameworkPrinter, FrameworkIntegrationPrinter)
from pystencils_autodiff.framework_integration.types import TemplateType


def test_dynamic_function():
    """Generate kernels that call a ``DynamicFunction`` passed in as a kernel argument.

    Two flavours of callee type are exercised: a concrete
    ``std::function<double(double)>`` and a ``TemplateType`` placeholder
    (``Functor_T``). The templated kernel is rendered with both the regular
    and the debug framework printer.
    """
    x, y = pystencils.fields('x, y:  float32[3d]')

    a = sp.symbols('a')

    # Callee with a concrete C++ type, called with the scalar argument `a`.
    my_fun_call = DynamicFunction(TypedSymbol('my_fun',
                                              'std::function<double(double)>'), create_type('double'), a)

    assignments = pystencils.AssignmentCollection({
        y.center: x.center + my_fun_call
    })

    ast = pystencils.create_kernel(assignments)
    # Call the printer directly so the generated source is available as a
    # string and can be checked, not just displayed.
    code = FrameworkIntegrationPrinter()(ast)
    print(code)
    assert 'my_fun' in code

    # Same callee name, but its type is a template parameter and it also
    # receives the field access `x.center` as a second argument.
    template_fun_call = DynamicFunction(TypedSymbol('my_fun',
                                                    TemplateType('Functor_T')), create_type('double'), a, x.center)

    assignments = pystencils.AssignmentCollection({
        y.center: x.center + template_fun_call
    })

    ast = pystencils.create_kernel(assignments)
    # Instantiate each printer only right before using it, preserving the
    # original order: the debug printer is created after the regular one.
    for printer_class in (FrameworkIntegrationPrinter, DebugFrameworkPrinter):
        code = printer_class()(ast)
        print(code)
        assert 'my_fun' in code


def test_dynamic_matrix():
    """Assign the matrix-valued result of a ``DynamicFunction`` to a ``TypedMatrixSymbol``.

    ``A`` is declared with shape arguments ``3, 1`` and C++ type
    ``Vector3<double>``. It receives the result of ``my_fun(a)``, and its
    three components are summed into ``y``.
    """
    _, y = pystencils.fields('x, y:  float32[3d]')
    from pystencils.data_types import TypedMatrixSymbol

    a = sp.symbols('a')

    A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')

    # The callee returns the same dtype as `A`, so its result can be assigned to `A`.
    my_fun_call = DynamicFunction(TypedSymbol('my_fun',
                                              'std::function<Vector3<double>(double)>'), A.dtype, a)

    assignments = pystencils.AssignmentCollection({
        A:  my_fun_call,
        y.center: A[0] + A[1] + A[2]
    })

    ast = pystencils.create_kernel(assignments)
    # Call the printer directly so the generated source is available as a
    # string and can be checked, not just displayed.
    code = FrameworkIntegrationPrinter()(ast)
    print(code)
    assert 'my_fun' in code


def test_dynamic_matrix_location_dependent():
    """Call a matrix-valued ``DynamicFunction`` with the entries of ``pystencils.x_vector(3)``.

    The callee is first typed as a concrete
    ``std::function<Vector3<double>(int, int, int)>`` and then as a template
    parameter ``Functor_T``. Its result is assigned to the ``TypedMatrixSymbol``
    ``A``, whose three components are summed into ``y``.
    """
    _, y = pystencils.fields('x, y:  float32[3d]')
    from pystencils.data_types import TypedMatrixSymbol

    A = TypedMatrixSymbol('A', 3, 1, create_type('double'), 'Vector3<double>')

    # The previous type string 'std: : function < Vector3 < double > (int, int, int) >'
    # was mangled and is not valid C++; use the proper std::function spelling.
    my_fun_call = DynamicFunction(TypedSymbol('my_fun',
                                              'std::function<Vector3<double>(int, int, int)>'),
                                  A.dtype,
                                  *pystencils.x_vector(3))

    assignments = pystencils.AssignmentCollection({
        A:  my_fun_call,
        y.center: A[0] + A[1] + A[2]
    })

    ast = pystencils.create_kernel(assignments)
    # Call the printer directly so the generated source is available as a
    # string and can be checked, not just displayed.
    code = FrameworkIntegrationPrinter()(ast)
    print(code)
    assert 'my_fun' in code

    # Same call, but the callee type is a template parameter.
    my_fun_call = DynamicFunction(TypedSymbol('my_fun',
                                              TemplateType('Functor_T')), A.dtype, *pystencils.x_vector(3))

    assignments = pystencils.AssignmentCollection({
        A:  my_fun_call,
        y.center: A[0] + A[1] + A[2]
    })

    ast = pystencils.create_kernel(assignments)
    code = FrameworkIntegrationPrinter()(ast)
    print(code)
    assert 'my_fun' in code