Skip to content
Snippets Groups Projects
Commit a18f9171 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add python bindinds for Torch/Pybind11

parent 07b58488
Branches
Tags
No related merge requests found
......@@ -11,14 +11,14 @@
from collections.abc import Iterable
from os.path import dirname, join
import jinja2
from pystencils.astnodes import (
DestructuringBindingsForFieldClass, FieldPointerSymbol, FieldShapeSymbol,
FieldStrideSymbol, Node)
from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils_autodiff._file_io import _read_template_from_file
from pystencils_autodiff.backends.python_bindings import (
PybindFunctionWrapping, PybindPythonBindings, TorchPythonBindings)
from pystencils_autodiff.framework_integration.astnodes import (
JinjaCppFile, WrapperFunction, generate_kernel_call)
from pystencils_autodiff.framework_integration.astnodes import DestructuringBindingsForFieldClass
# Torch
......@@ -64,6 +64,7 @@ class PybindArrayDestructuring(DestructuringBindingsForFieldClass):
class TorchModule(JinjaCppFile):
TEMPLATE = _read_template_from_file(join(dirname(__file__), 'module.tmpl.cpp'))
DESTRUCTURING_CLASS = TorchTensorDestructuring
PYTHON_BINDINGS_CLASS = TorchPythonBindings
def __init__(self, module_name, kernel_asts):
"""Create a C++ module with forward and optional backward_kernels
......@@ -78,7 +79,8 @@ class TorchModule(JinjaCppFile):
ast_dict = {
'kernels': kernel_asts,
'kernel_wrappers': wrapper_functions,
'python_bindings': PybindPythonBindings(module_name, [PybindFunctionWrapping(a) for a in wrapper_functions])
'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name,
[PybindFunctionWrapping(a) for a in wrapper_functions])
}
super().__init__(ast_dict)
......@@ -105,32 +107,4 @@ class TensorflowModule(TorchModule):
class PybindModule(TorchModule):
DESTRUCTURING_CLASS = PybindArrayDestructuring
class PybindPythonBindings(JinjaCppFile):
TEMPLATE = jinja2.Template("""PYBIND11_MODULE("{{ module_name }}", m)
{
{% for ast_node in module_contents -%}
{{ ast_node | indent(3,true) }}
{% endfor -%}
}
""")
def __init__(self, module_name, astnodes_to_wrap):
super().__init__({'module_name': module_name, 'module_contents': astnodes_to_wrap})
class PybindFunctionWrapping(JinjaCppFile):
TEMPLATE = jinja2.Template(
"""m.def("{{ python_name }}", &{{ cpp_name }}, {% for p in parameters -%}"{{ p }}"_a{{- ", " if not loop.last -}}{% endfor %});""" # noqa
)
required_global_declarations = ["using namespace pybind11::literals;"]
headers = ['<pybind11/pybind11.h>',
'<pybind11/stl.h>']
def __init__(self, function_node):
super().__init__({'python_name': function_node.function_name,
'cpp_name': function_node.function_name,
'parameters': [p.symbol.name for p in function_node.get_parameters()]
})
PYTHON_BINDINGS_CLASS = PybindPythonBindings
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import jinja2
from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile
class PybindPythonBindings(JinjaCppFile):
TEMPLATE = jinja2.Template("""PYBIND11_MODULE("{{ module_name }}", m)
{
{% for ast_node in module_contents -%}
{{ ast_node | indent(3,true) }}
{% endfor -%}
}
""")
def __init__(self, module_name, astnodes_to_wrap):
super().__init__({'module_name': module_name, 'module_contents': astnodes_to_wrap})
class TorchPythonBindings(JinjaCppFile):
TEMPLATE = jinja2.Template("""PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
{% for ast_node in module_contents -%}
{{ ast_node | indent(3,true) }}
{% endfor -%}
}
""")
headers = ['<torch/extension.h>']
def __init__(self, module_name, astnodes_to_wrap):
super().__init__({'module_contents': astnodes_to_wrap})
class PybindFunctionWrapping(JinjaCppFile):
TEMPLATE = jinja2.Template(
"""m.def("{{ python_name }}", &{{ cpp_name }}, {% for p in parameters -%}"{{ p }}"_a{{- ", " if not loop.last -}}{% endfor %});""" # noqa
)
required_global_declarations = ["using namespace pybind11::literals;"]
headers = ['<pybind11/pybind11.h>',
'<pybind11/stl.h>']
def __init__(self, function_node):
super().__init__({'python_name': function_node.function_name,
'cpp_name': function_node.function_name,
'parameters': [p.symbol.name for p in function_node.get_parameters()]
})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment