Skip to content
Snippets Groups Projects
Commit 3b2eb160 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Update test_native_tensorflow_compilation

parent bce14b17
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,7 @@ import pytest ...@@ -18,6 +18,7 @@ import pytest
import sympy import sympy
import pystencils import pystencils
from pystencils.cpu.cpujit import get_compiler_config
from pystencils.include import get_pystencils_include_path from pystencils.include import get_pystencils_include_path
from pystencils_autodiff import create_backward_assignments from pystencils_autodiff import create_backward_assignments
from pystencils_autodiff._file_io import write_file from pystencils_autodiff._file_io import write_file
...@@ -111,17 +112,23 @@ def test_native_tensorflow_compilation_gpu(): ...@@ -111,17 +112,23 @@ def test_native_tensorflow_compilation_gpu():
backward_ast = pystencils.create_kernel(backward_assignments, target) backward_ast = pystencils.create_kernel(backward_assignments, target)
backward_ast.function_name = 'backward' backward_ast.function_name = 'backward'
module = TensorflowModule(module_name, [forward_ast, backward_ast]) module = TensorflowModule(module_name, [forward_ast, backward_ast])
print(module) print(str(module))
temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp') temp_file = tempfile.NamedTemporaryFile(suffix='.cu' if target == 'gpu' else '.cpp')
print(temp_file.name) print(temp_file.name)
write_file(temp_file.name, str(module)) write_file(temp_file.name, str(module))
if 'tensorflow_host_compiler' not in get_compiler_config():
get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command']
# on my machine g++-6 and clang-7 are working # on my machine g++-6 and clang-7 are working
# '-ccbin',
# 'g++-6',
command = ['nvcc', command = ['nvcc',
temp_file.name, temp_file.name,
'-lcudart', '-lcudart',
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'-ccbin',
get_compiler_config()['tensorflow_host_compiler'],
'-lcudart', '-lcudart',
'-std=c++14', '-std=c++14',
'-x', '-x',
...@@ -130,18 +137,17 @@ def test_native_tensorflow_compilation_gpu(): ...@@ -130,18 +137,17 @@ def test_native_tensorflow_compilation_gpu():
'-fPIC', '-fPIC',
'-c', '-c',
'-o', '-o',
'foo_gpu.o'] + link_flags + compile_flags + extra_flags 'foo_gpu.o'] + compile_flags + extra_flags
print(command)
subprocess.check_call(command) subprocess.check_call(command)
# command = ['clang-7', '-shared', temp_file.name, '--cuda-path=/usr/include', '-std=c++14', # command = ['clang-7', '-shared', temp_file.name, '--cuda-path=/usr/include', '-std=c++14',
# '-fPIC', '-lcudart', '-o', 'foo.so'] + compile_flags + link_flags + extra_flags # '-fPIC', '-lcudart', '-o', 'foo.so'] + compile_flags + link_flags + extra_flags
command = ['c++', '-std=c++14', '-fPIC', '-lcudart', 'foo_gpu.o', command = ['c++', '-fPIC', '-lcudart', 'foo_gpu.o',
'-shared', '-o', 'foo.so'] + compile_flags + link_flags + extra_flags '-shared', '-o', 'foo.so'] + link_flags
print(command)
subprocess.check_call(command) subprocess.check_call(command)
lib = tf.load_op_library(join(os.getcwd(), 'foo.so')) lib = tf.load_op_library(join(os.getcwd(), 'foo.so'))
print(dir(lib))
assert 'call_forward' in dir(lib) assert 'call_forward' in dir(lib)
assert 'call_backward' in dir(lib) assert 'call_backward' in dir(lib)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment