From 48d20529787894db29781ebae55a860fb2b709c9 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 25 Nov 2019 15:45:41 +0100 Subject: [PATCH] Optionally generate PyTorchModule without Python bindings --- src/pystencils_autodiff/backends/astnodes.py | 3 ++- .../backends/test_torch_native_compilation.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 35287b8..3355271 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -74,7 +74,7 @@ class TorchModule(JinjaCppFile): def backend(self): return 'gpucuda' if self.is_cuda else 'c' - def __init__(self, module_name, kernel_asts): + def __init__(self, module_name, kernel_asts, with_python_bindings=True): """Create a C++ module with forward and optional backward_kernels :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect) @@ -95,6 +95,7 @@ class TorchModule(JinjaCppFile): 'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name, [self.PYTHON_FUNCTION_WRAPPING_CLASS(a) for a in wrapper_functions]) + if with_python_bindings else '' } super().__init__(ast_dict) diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py index bde5df9..44265fc 100644 --- a/tests/backends/test_torch_native_compilation.py +++ b/tests/backends/test_torch_native_compilation.py @@ -9,15 +9,15 @@ from os.path import dirname, isfile, join import numpy as np import pytest +import sympy import pystencils -import sympy from pystencils_autodiff import create_backward_assignments from pystencils_autodiff._file_io import write_cached_content from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule torch = pytest.importorskip('torch') -pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--v']) != 0, +pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--version']) != 0, reason='torch compilation requires ninja') @@ -78,7 +78,8 @@ def test_torch_native_compilation_cpu(): assert 'call_backward' in dir(torch_extension) -def test_pybind11_compilation_cpu(): +@pytest.mark.parametrize('with_python_bindings', ('with_python_bindings', False)) +def test_pybind11_compilation_cpu(with_python_bindings): pytest.importorskip('pybind11') pytest.importorskip('cppimport') @@ -100,13 +101,14 @@ def test_pybind11_compilation_cpu(): forward_ast.function_name = 'forward' backward_ast = pystencils.create_kernel(backward_assignments, target) backward_ast.function_name = 'backward' - module = PybindModule(module_name, [forward_ast, backward_ast]) + module = PybindModule(module_name, [forward_ast, backward_ast], with_python_bindings=with_python_bindings) print(module) - pybind_extension = module.compile() - assert pybind_extension is not None - assert 'call_forward' in dir(pybind_extension) - assert 'call_backward' in dir(pybind_extension) + if with_python_bindings: + pybind_extension = module.compile() + assert pybind_extension is not None + assert 'call_forward' in dir(pybind_extension) + assert 'call_backward' in dir(pybind_extension) @pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS") -- GitLab