Skip to content
Snippets Groups Projects
Commit f379fb13 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Printing of Python bindings now working for Torch/pybind11

parent 324c3211
No related branches found
No related tags found
No related merge requests found
Pipeline #17374 failed
......@@ -11,12 +11,14 @@
from collections.abc import Iterable
from os.path import dirname, join
import jinja2
from pystencils.astnodes import (
DestructuringBindingsForFieldClass, FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol)
DestructuringBindingsForFieldClass, FieldPointerSymbol, FieldShapeSymbol,
FieldStrideSymbol, Node)
from pystencils_autodiff._file_io import _read_template_from_file
from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile
from pystencils_autodiff.framework_integration.astnodes import (
WrapperFunction, generate_kernel_call)
JinjaCppFile, WrapperFunction, generate_kernel_call)
# Torch
......@@ -63,7 +65,7 @@ class TorchModule(JinjaCppFile):
TEMPLATE = _read_template_from_file(join(dirname(__file__), 'module.tmpl.cpp'))
DESTRUCTURING_CLASS = TorchTensorDestructuring
def __init__(self, kernel_asts):
def __init__(self, module_name, kernel_asts):
"""Create a C++ module with forward and optional backward_kernels
:param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
......@@ -71,10 +73,12 @@ class TorchModule(JinjaCppFile):
"""
if not isinstance(kernel_asts, Iterable):
kernel_asts = [kernel_asts]
wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts]
ast_dict = {
'kernels': kernel_asts,
'kernel_wrappers': [self.generate_wrapper_function(k) for k in kernel_asts],
'kernel_wrappers': wrapper_functions,
'python_bindings': PybindPythonBindings(module_name, [PybindFunctionWrapping(a) for a in wrapper_functions])
}
super().__init__(ast_dict)
......@@ -87,6 +91,46 @@ class TorchModule(JinjaCppFile):
class TensorflowModule(TorchModule):
DESTRUCTURING_CLASS = TensorflowTensorDestructuring
def __init__(self, module_name, forward_to_backward_kernel_dict):
"""Create a C++ module with forward and optional backward_kernels
:param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
:param backward_kernel_ast:
"""
self._forward_to_backward_dict = forward_to_backward_kernel_dict
kernel_asts = list(forward_to_backward_kernel_dict.values()) + list(forward_to_backward_kernel_dict.keys())
super().__init__(module_name, kernel_asts)
class PybindModule(TorchModule):
DESTRUCTURING_CLASS = PybindArrayDestructuring
class PybindPythonBindings(JinjaCppFile):
TEMPLATE = jinja2.Template("""PYBIND11_MODULE("{{ module_name }}", m)
{
{% for ast_node in module_contents -%}
{{ ast_node | indent(3,true) }}
{% endfor -%}
}
""")
def __init__(self, module_name, astnodes_to_wrap):
super().__init__({'module_name': module_name, 'module_contents': astnodes_to_wrap})
class PybindFunctionWrapping(JinjaCppFile):
TEMPLATE = jinja2.Template(
"""m.def("{{ python_name }}", &{{ cpp_name }}, {% for p in parameters -%}"{{ p }}"_a{{- ", " if not loop.last -}}{% endfor %});""" # noqa
)
required_global_declarations = ["using namespace pybind11::literals;"]
headers = ['<pybind11/pybind11.h>',
'<pybind11/stl.h>']
def __init__(self, function_node):
super().__init__({'python_name': function_node.function_name,
'cpp_name': function_node.function_name,
'parameters': [p.symbol.name for p in function_node.get_parameters()]
})
#include <cuda.h>
#include <vector>
// Most compilers don't care whether it's __restrict or __restrict__
#define RESTRICT __restrict__
{% for header in headers -%}
......
......@@ -9,6 +9,7 @@ Astnodes useful for the generations of C modules for frameworks (apart from waLB
waLBerla currently uses `pystencils-walberla <https://pypi.org/project/pystencils-walberla/>`_.
"""
import itertools
from collections.abc import Iterable
from typing import Any, List, Set
......@@ -232,7 +233,9 @@ class JinjaCppFile(Node):
@property
def args(self):
"""Returns all arguments/children of this node."""
return self.ast_dict.values()
ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, str))]
ast_nodes_iterables = [a for a in self.ast_dict.values() if not isinstance(a, (Node, str))]
return ast_nodes + list(itertools.chain.from_iterable(ast_nodes_iterables))
@property
def symbols_defined(self):
......@@ -252,9 +255,12 @@ class JinjaCppFile(Node):
def __str__(self):
assert self.TEMPLATE, f"Template of {self.__class__} must be set"
render_dict = {k: self._print(v) if not isinstance(v, Iterable) else [self._print(a) for a in v]
render_dict = {k: self._print(v)
if not isinstance(v, Iterable) or isinstance(v, str)
else [self._print(a) for a in v]
for k, v in self.ast_dict.items()}
render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)})
render_dict.update({"globals": pystencils.backends.cbackend.get_global_declarations(self)})
return self.TEMPLATE.render(render_dict)
......
......@@ -8,11 +8,12 @@
"""
import pytest
import sympy
import pystencils
from pystencils_autodiff import create_backward_assignments
from pystencils_autodiff.backends.astnodes import PybindModule, TensorflowModule, TorchModule
from pystencils_autodiff.backends.astnodes import (
PybindFunctionWrapping, PybindModule, PybindPythonBindings, TensorflowModule, TorchModule)
try:
from pystencils.interpolation_astnodes import TextureCachedField
......@@ -22,6 +23,7 @@ except ImportError:
def test_module_printing():
module_name = "my_module"
for target in ('cpu', 'gpu'):
z, y, x = pystencils.fields("z, y, x: [2d]")
......@@ -36,20 +38,22 @@ def test_module_printing():
forward_ast.function_name = 'forward'
backward_ast = pystencils.create_kernel(backward_assignments, target)
backward_ast.function_name = 'backward'
module = TorchModule([forward_ast, backward_ast])
module = TorchModule(module_name, [forward_ast, backward_ast])
print(module)
module = TensorflowModule({forward_ast: backward_ast})
module = TensorflowModule(module_name, {forward_ast: backward_ast})
print(module)
if target == 'cpu':
module = PybindModule([forward_ast, backward_ast])
module = PybindModule(module_name, [forward_ast, backward_ast])
print(module)
module = PybindModule(forward_ast)
module = PybindModule(module_name, forward_ast)
print(module)
def test_module_printing_parameter():
module_name = "Ololol"
for target in ('cpu', 'gpu'):
z, y, x = pystencils.fields("z, y, x: [20,40]")
......@@ -65,16 +69,16 @@ def test_module_printing_parameter():
forward_ast.function_name = 'forward'
backward_ast = pystencils.create_kernel(backward_assignments, target)
backward_ast.function_name = 'backward'
module = TorchModule([forward_ast, backward_ast])
module = TorchModule(module_name, [forward_ast, backward_ast])
print(module)
module = TensorflowModule({forward_ast: backward_ast})
module = TensorflowModule(module_name, {forward_ast: backward_ast})
print(module)
if target == 'cpu':
module = PybindModule([forward_ast, backward_ast])
module = PybindModule(module_name, [forward_ast, backward_ast])
print(module)
module = PybindModule(forward_ast)
module = PybindModule(module_name, forward_ast)
print(module)
......@@ -94,7 +98,7 @@ def test_module_printing_globals():
forward_ast.function_name = 'forward'
backward_ast = pystencils.create_kernel(backward_assignments, target)
backward_ast.function_name = 'backward'
module = TorchModule([forward_ast, backward_ast])
module = TorchModule("hallo", [forward_ast, backward_ast])
print(module)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment