diff --git a/tests/test_native_tensorflow_compilation.py b/tests/test_native_tensorflow_compilation.py index 864c4a00bb423520e9bcd74cd9e2a7b7f9561eda..f8c5c8dbf6f4fe7d82c5b62892cd026acf907ecc 100644 --- a/tests/test_native_tensorflow_compilation.py +++ b/tests/test_native_tensorflow_compilation.py @@ -18,6 +18,7 @@ import pytest import sympy import pystencils +from pystencils.cpu.cpujit import get_compiler_config from pystencils.include import get_pystencils_include_path from pystencils_autodiff import create_backward_assignments from pystencils_autodiff._file_io import write_file @@ -111,17 +112,23 @@ def test_native_tensorflow_compilation_gpu(): backward_ast = pystencils.create_kernel(backward_assignments, target) backward_ast.function_name = 'backward' module = TensorflowModule(module_name, [forward_ast, backward_ast]) - print(module) + print(str(module)) temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp') print(temp_file.name) write_file(temp_file.name, str(module)) + if 'tensorflow_host_compiler' not in get_compiler_config(): + get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command'] # on my machine g++-6 and clang-7 are working + # '-ccbin', + # 'g++-6', command = ['nvcc', temp_file.name, '-lcudart', '--expt-relaxed-constexpr', + '-ccbin', + get_compiler_config()['tensorflow_host_compiler'], '-lcudart', '-std=c++14', '-x', @@ -130,18 +137,17 @@ def test_native_tensorflow_compilation_gpu(): '-fPIC', '-c', '-o', - 'foo_gpu.o'] + link_flags + compile_flags + extra_flags - print(command) + 'foo_gpu.o'] + compile_flags + extra_flags + subprocess.check_call(command) # command = ['clang-7', '-shared', temp_file.name, '--cuda-path=/usr/include', '-std=c++14', # '-fPIC', '-lcudart', '-o', 'foo.so'] + compile_flags + link_flags + extra_flags - command = ['c++', '-std=c++14', '-fPIC', '-lcudart', 'foo_gpu.o', - '-shared', '-o', 'foo.so'] + compile_flags + link_flags + extra_flags - print(command) + command = ['c++', '-fPIC', '-lcudart', 'foo_gpu.o', + '-shared', '-o', 'foo.so'] + link_flags + subprocess.check_call(command) lib = tf.load_op_library(join(os.getcwd(), 'foo.so')) - print(dir(lib)) assert 'call_forward' in dir(lib) assert 'call_backward' in dir(lib)