From 48d20529787894db29781ebae55a860fb2b709c9 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 25 Nov 2019 15:45:41 +0100
Subject: [PATCH] Optionally generate PyTorchModule without Python bindings

---
 src/pystencils_autodiff/backends/astnodes.py   |  3 ++-
 .../backends/test_torch_native_compilation.py  | 18 ++++++++++--------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 35287b8..3355271 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -74,7 +74,7 @@ class TorchModule(JinjaCppFile):
     def backend(self):
         return 'gpucuda' if self.is_cuda else 'c'
 
-    def __init__(self, module_name, kernel_asts):
+    def __init__(self, module_name, kernel_asts, with_python_bindings=True):
         """Create a C++ module with forward and optional backward_kernels
 
         :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
@@ -95,6 +95,7 @@ class TorchModule(JinjaCppFile):
             'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name,
                                                           [self.PYTHON_FUNCTION_WRAPPING_CLASS(a)
                                                               for a in wrapper_functions])
+            if with_python_bindings else ''
         }
 
         super().__init__(ast_dict)
diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py
index bde5df9..44265fc 100644
--- a/tests/backends/test_torch_native_compilation.py
+++ b/tests/backends/test_torch_native_compilation.py
@@ -9,15 +9,15 @@ from os.path import dirname, isfile, join
 
 import numpy as np
 import pytest
+import sympy
 
 import pystencils
-import sympy
 from pystencils_autodiff import create_backward_assignments
 from pystencils_autodiff._file_io import write_cached_content
 from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule
 
 torch = pytest.importorskip('torch')
-pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--v']) != 0,
+pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--version']) != 0,
                                 reason='torch compilation requires ninja')
 
 
@@ -78,7 +78,8 @@ def test_torch_native_compilation_cpu():
     assert 'call_backward' in dir(torch_extension)
 
 
-def test_pybind11_compilation_cpu():
+@pytest.mark.parametrize('with_python_bindings', ('with_python_bindings', False))
+def test_pybind11_compilation_cpu(with_python_bindings):
 
     pytest.importorskip('pybind11')
     pytest.importorskip('cppimport')
@@ -100,13 +101,14 @@ def test_pybind11_compilation_cpu():
     forward_ast.function_name = 'forward'
     backward_ast = pystencils.create_kernel(backward_assignments, target)
     backward_ast.function_name = 'backward'
-    module = PybindModule(module_name, [forward_ast, backward_ast])
+    module = PybindModule(module_name, [forward_ast, backward_ast], with_python_bindings=with_python_bindings)
     print(module)
 
-    pybind_extension = module.compile()
-    assert pybind_extension is not None
-    assert 'call_forward' in dir(pybind_extension)
-    assert 'call_backward' in dir(pybind_extension)
+    if with_python_bindings:
+        pybind_extension = module.compile()
+        assert pybind_extension is not None
+        assert 'call_forward' in dir(pybind_extension)
+        assert 'call_backward' in dir(pybind_extension)
 
 
 @pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
-- 
GitLab