From 3b2eb1608fe94c3ff56cbe5c13ee78ddbd3c5b7f Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 12 Sep 2019 18:17:01 +0200
Subject: [PATCH] Update test_native_tensorflow_compilation

---
 tests/test_native_tensorflow_compilation.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/tests/test_native_tensorflow_compilation.py b/tests/test_native_tensorflow_compilation.py
index 864c4a0..f8c5c8d 100644
--- a/tests/test_native_tensorflow_compilation.py
+++ b/tests/test_native_tensorflow_compilation.py
@@ -18,6 +18,7 @@ import pytest
 import sympy
 
 import pystencils
+from pystencils.cpu.cpujit import get_compiler_config
 from pystencils.include import get_pystencils_include_path
 from pystencils_autodiff import create_backward_assignments
 from pystencils_autodiff._file_io import write_file
@@ -111,17 +112,23 @@ def test_native_tensorflow_compilation_gpu():
     backward_ast = pystencils.create_kernel(backward_assignments, target)
     backward_ast.function_name = 'backward'
     module = TensorflowModule(module_name, [forward_ast, backward_ast])
-    print(module)
+    print(str(module))
 
     temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp')
     print(temp_file.name)
     write_file(temp_file.name, str(module))
+    if 'tensorflow_host_compiler' not in get_compiler_config():
+        get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command']
 
     # on my machine g++-6 and clang-7 are working
+    # '-ccbin',
+    # 'g++-6',
     command = ['nvcc',
                temp_file.name,
                '-lcudart',
                '--expt-relaxed-constexpr',
+               '-ccbin',
+               get_compiler_config()['tensorflow_host_compiler'],
                '-lcudart',
                '-std=c++14',
                '-x',
@@ -130,18 +137,17 @@ def test_native_tensorflow_compilation_gpu():
                '-fPIC',
                '-c',
                '-o',
-               'foo_gpu.o'] + link_flags + compile_flags + extra_flags
-    print(command)
+               'foo_gpu.o'] + compile_flags + extra_flags
+
     subprocess.check_call(command)
 
     # command = ['clang-7', '-shared', temp_file.name, '--cuda-path=/usr/include',  '-std=c++14',
     # '-fPIC', '-lcudart', '-o', 'foo.so'] + compile_flags + link_flags + extra_flags
-    command = ['c++', '-std=c++14', '-fPIC', '-lcudart', 'foo_gpu.o',
-               '-shared', '-o', 'foo.so'] + compile_flags + link_flags + extra_flags
-    print(command)
+    command = ['c++', '-fPIC', '-lcudart', 'foo_gpu.o',
+               '-shared', '-o', 'foo.so'] + link_flags
+
     subprocess.check_call(command)
     lib = tf.load_op_library(join(os.getcwd(), 'foo.so'))
 
-    print(dir(lib))
     assert 'call_forward' in dir(lib)
     assert 'call_backward' in dir(lib)
-- 
GitLab