Skip to content
Snippets Groups Projects
Commit 48d20529 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Optionally generate PyTorchModule without Python bindings

parent 977960a4
No related branches found
No related tags found
No related merge requests found
......@@ -74,7 +74,7 @@ class TorchModule(JinjaCppFile):
def backend(self):
return 'gpucuda' if self.is_cuda else 'c'
def __init__(self, module_name, kernel_asts):
def __init__(self, module_name, kernel_asts, with_python_bindings=True):
"""Create a C++ module with forward and optional backward_kernels
:param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
......@@ -95,6 +95,7 @@ class TorchModule(JinjaCppFile):
'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name,
[self.PYTHON_FUNCTION_WRAPPING_CLASS(a)
for a in wrapper_functions])
if with_python_bindings else ''
}
super().__init__(ast_dict)
......
......@@ -9,15 +9,15 @@ from os.path import dirname, isfile, join
import numpy as np
import pytest
import sympy
import pystencils
import sympy
from pystencils_autodiff import create_backward_assignments
from pystencils_autodiff._file_io import write_cached_content
from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule
torch = pytest.importorskip('torch')
pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--v']) != 0,
pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--version']) != 0,
reason='torch compilation requires ninja')
......@@ -78,7 +78,8 @@ def test_torch_native_compilation_cpu():
assert 'call_backward' in dir(torch_extension)
def test_pybind11_compilation_cpu():
@pytest.mark.parametrize('with_python_bindings', ('with_python_bindings', False))
def test_pybind11_compilation_cpu(with_python_bindings):
pytest.importorskip('pybind11')
pytest.importorskip('cppimport')
......@@ -100,13 +101,14 @@ def test_pybind11_compilation_cpu():
forward_ast.function_name = 'forward'
backward_ast = pystencils.create_kernel(backward_assignments, target)
backward_ast.function_name = 'backward'
module = PybindModule(module_name, [forward_ast, backward_ast])
module = PybindModule(module_name, [forward_ast, backward_ast], with_python_bindings=with_python_bindings)
print(module)
pybind_extension = module.compile()
assert pybind_extension is not None
assert 'call_forward' in dir(pybind_extension)
assert 'call_backward' in dir(pybind_extension)
if with_python_bindings:
pybind_extension = module.compile()
assert pybind_extension is not None
assert 'call_forward' in dir(pybind_extension)
assert 'call_backward' in dir(pybind_extension)
@pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment