Skip to content
Snippets Groups Projects
Commit 48d20529 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Optionally generate PyTorchModule without Python bindings

parent 977960a4
No related branches found
No related tags found
No related merge requests found
...@@ -74,7 +74,7 @@ class TorchModule(JinjaCppFile): ...@@ -74,7 +74,7 @@ class TorchModule(JinjaCppFile):
def backend(self): def backend(self):
return 'gpucuda' if self.is_cuda else 'c' return 'gpucuda' if self.is_cuda else 'c'
def __init__(self, module_name, kernel_asts): def __init__(self, module_name, kernel_asts, with_python_bindings=True):
"""Create a C++ module with forward and optional backward_kernels """Create a C++ module with forward and optional backward_kernels
:param forward_kernel_ast: one or more kernel ASTs (can have any C dialect) :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
...@@ -95,6 +95,7 @@ class TorchModule(JinjaCppFile): ...@@ -95,6 +95,7 @@ class TorchModule(JinjaCppFile):
'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name, 'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name,
[self.PYTHON_FUNCTION_WRAPPING_CLASS(a) [self.PYTHON_FUNCTION_WRAPPING_CLASS(a)
for a in wrapper_functions]) for a in wrapper_functions])
if with_python_bindings else ''
} }
super().__init__(ast_dict) super().__init__(ast_dict)
......
...@@ -9,15 +9,15 @@ from os.path import dirname, isfile, join ...@@ -9,15 +9,15 @@ from os.path import dirname, isfile, join
import numpy as np import numpy as np
import pytest import pytest
import sympy
import pystencils import pystencils
import sympy
from pystencils_autodiff import create_backward_assignments from pystencils_autodiff import create_backward_assignments
from pystencils_autodiff._file_io import write_cached_content from pystencils_autodiff._file_io import write_cached_content
from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule
torch = pytest.importorskip('torch') torch = pytest.importorskip('torch')
pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--v']) != 0, pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--version']) != 0,
reason='torch compilation requires ninja') reason='torch compilation requires ninja')
...@@ -78,7 +78,8 @@ def test_torch_native_compilation_cpu(): ...@@ -78,7 +78,8 @@ def test_torch_native_compilation_cpu():
assert 'call_backward' in dir(torch_extension) assert 'call_backward' in dir(torch_extension)
def test_pybind11_compilation_cpu(): @pytest.mark.parametrize('with_python_bindings', ('with_python_bindings', False))
def test_pybind11_compilation_cpu(with_python_bindings):
pytest.importorskip('pybind11') pytest.importorskip('pybind11')
pytest.importorskip('cppimport') pytest.importorskip('cppimport')
...@@ -100,13 +101,14 @@ def test_pybind11_compilation_cpu(): ...@@ -100,13 +101,14 @@ def test_pybind11_compilation_cpu():
forward_ast.function_name = 'forward' forward_ast.function_name = 'forward'
backward_ast = pystencils.create_kernel(backward_assignments, target) backward_ast = pystencils.create_kernel(backward_assignments, target)
backward_ast.function_name = 'backward' backward_ast.function_name = 'backward'
module = PybindModule(module_name, [forward_ast, backward_ast]) module = PybindModule(module_name, [forward_ast, backward_ast], with_python_bindings=with_python_bindings)
print(module) print(module)
pybind_extension = module.compile() if with_python_bindings:
assert pybind_extension is not None pybind_extension = module.compile()
assert 'call_forward' in dir(pybind_extension) assert pybind_extension is not None
assert 'call_backward' in dir(pybind_extension) assert 'call_forward' in dir(pybind_extension)
assert 'call_backward' in dir(pybind_extension)
@pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS") @pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
......
0% — Loading, or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment