Skip to content
Snippets Groups Projects
Commit 5782627b authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Make tensorflow compilation work (single) compilation

parent 3b2eb160
No related branches found
No related tags found
No related merge requests found
Pipeline #17924 failed
...@@ -19,8 +19,6 @@ from pystencils_autodiff.backends.python_bindings import ( ...@@ -19,8 +19,6 @@ from pystencils_autodiff.backends.python_bindings import (
from pystencils_autodiff.framework_integration.astnodes import ( from pystencils_autodiff.framework_integration.astnodes import (
DestructuringBindingsForFieldClass, JinjaCppFile, WrapperFunction, generate_kernel_call) DestructuringBindingsForFieldClass, JinjaCppFile, WrapperFunction, generate_kernel_call)
# Torch
class TorchTensorDestructuring(DestructuringBindingsForFieldClass): class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
CLASS_TO_MEMBER_DICT = { CLASS_TO_MEMBER_DICT = {
......
...@@ -101,8 +101,6 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho ...@@ -101,8 +101,6 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho
parameters = function_node.get_parameters() parameters = function_node.get_parameters()
output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes
print([f for f in function_node.atoms(Node)])
docstring = "TODO" # TODO docstring = "TODO" # TODO
# this looks almost like lisp 😕 # this looks almost like lisp 😕
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
""" """
""" """
import json
import subprocess import subprocess
import sysconfig import sysconfig
from itertools import chain from itertools import chain
...@@ -24,11 +25,13 @@ if get_compiler_config()['os'] != 'windows': ...@@ -24,11 +25,13 @@ if get_compiler_config()['os'] != 'windows':
_output_flag = '-o' _output_flag = '-o'
_include_flags = ['-I' + sysconfig.get_paths()['include'], '-I' + get_pystencils_include_path()] _include_flags = ['-I' + sysconfig.get_paths()['include'], '-I' + get_pystencils_include_path()]
_do_not_link_flag = "-c" _do_not_link_flag = "-c"
_position_independent_flag = "-fPIC"
else: else:
_do_not_link_flag = "/c" _do_not_link_flag = "/c"
_output_flag = '/OUT:' _output_flag = '/OUT:'
_shared_object_flag = '/DLL' _shared_object_flag = '/DLL'
_include_flags = ['/I' + sysconfig.get_paths()['include'], '/I' + get_pystencils_include_path()] _include_flags = ['/I' + sysconfig.get_paths()['include'], '/I' + get_pystencils_include_path()]
_position_independent_flag = "/DTHIS_FLAG_DOES_NOTHING"
try: try:
...@@ -40,7 +43,7 @@ except ImportError: ...@@ -40,7 +43,7 @@ except ImportError:
pass pass
def link_and_load(object_files, destination_file=None, link_cudart=False, overwrite_destination_file=True): def link_and_load(object_files, destination_file=None, overwrite_destination_file=True):
"""Compiles given :param:`source_file` to a Tensorflow shared Library. """Compiles given :param:`source_file` to a Tensorflow shared Library.
.. warning:: .. warning::
...@@ -58,15 +61,13 @@ def link_and_load(object_files, destination_file=None, link_cudart=False, overwr ...@@ -58,15 +61,13 @@ def link_and_load(object_files, destination_file=None, link_cudart=False, overwr
if not exists(destination_file) or overwrite_destination_file: if not exists(destination_file) or overwrite_destination_file:
command = [get_compiler_config()['command'], command = [get_compiler_config()['command'],
*(get_compiler_config()['flags']).split(' '), _position_independent_flag,
*object_files, *object_files,
*_tf_link_flags, *_tf_link_flags,
*_tf_compile_flags,
*_include_flags, *_include_flags,
_shared_object_flag, _shared_object_flag,
_output_flag + destination_file] # /out: for msvc??? _output_flag,
if link_cudart: destination_file] # /out: for msvc???
command.append('-lcudart')
subprocess.check_call(command) subprocess.check_call(command)
...@@ -81,34 +82,46 @@ def try_get_cuda_arch_flag(): ...@@ -81,34 +82,46 @@ def try_get_cuda_arch_flag():
except Exception: except Exception:
arch = None arch = None
if arch: if arch:
return "-arch " + arch return "-arch=" + arch
else: else:
return None return None
_cuda_arch_flag = try_get_cuda_arch_flag() _cuda_arch_flag = try_get_cuda_arch_flag()
_nvcc_flags = ["-w", "-std=c++14", "-Wno-deprecated-gpu-targets"]
if _cuda_arch_flag:
_nvcc_flags.append(_cuda_arch_flag)
if pystencils.gpucuda.cudajit.USE_FAST_MATH:
_nvcc_flags.append('-use_fast_math')
def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True): def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True):
if 'tensorflow_host_compiler' not in get_compiler_config():
get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command']
write_file(pystencils.cpu.cpujit.get_configuration_file_path(), json.dumps(pystencils.cpu.cpujit._config))
destination_file = file + '.o' destination_file = file + '.o'
if use_nvcc: if use_nvcc:
command = [nvcc, command = [nvcc,
'--expt-relaxed-constexpr', '--expt-relaxed-constexpr',
'-ccbin', '-ccbin',
get_compiler_config()['command'], get_compiler_config()['tensorflow_host_compiler'],
*(get_compiler_config()['flags']).split(' '), '-Xcompiler',
get_compiler_config()['flags'].replace('c++11', 'c++14'),
*_nvcc_flags,
file, file,
'-x', '-x',
'cu', 'cu',
'-Xcompiler', '-Xcompiler',
'-fPIC', # TODO: msvc! _position_independent_flag,
_do_not_link_flag, _do_not_link_flag,
*_tf_compile_flags, *_tf_compile_flags,
*_include_flags, *_include_flags,
_output_flag + destination_file] _output_flag,
if _cuda_arch_flag: destination_file]
command.append(_cuda_arch_flag)
else: else:
command = [get_compiler_config()['command'], command = [get_compiler_config()['command'],
*(get_compiler_config()['flags']).split(' '), *(get_compiler_config()['flags']).split(' '),
...@@ -116,7 +129,9 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T ...@@ -116,7 +129,9 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
_do_not_link_flag, _do_not_link_flag,
*_tf_compile_flags, *_tf_compile_flags,
*_include_flags, *_include_flags,
_output_flag + destination_file] _output_flag,
destination_file]
if not exists(destination_file) or overwrite_destination_file: if not exists(destination_file) or overwrite_destination_file:
subprocess.check_call(command) subprocess.check_call(command)
return destination_file return destination_file
...@@ -126,7 +141,7 @@ def compile_sources_and_load(host_sources, cuda_sources=[]): ...@@ -126,7 +141,7 @@ def compile_sources_and_load(host_sources, cuda_sources=[]):
object_files = [] object_files = []
for source in tqdm(chain(host_sources, cuda_sources), desc='Compiling Tensorflow module...'): for source in tqdm(host_sources + cuda_sources, desc='Compiling Tensorflow module'):
is_cuda = source in cuda_sources is_cuda = source in cuda_sources
if exists(source): if exists(source):
...@@ -139,10 +154,10 @@ def compile_sources_and_load(host_sources, cuda_sources=[]): ...@@ -139,10 +154,10 @@ def compile_sources_and_load(host_sources, cuda_sources=[]):
write_file(file_name, source_code) write_file(file_name, source_code)
compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False) compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False)
object_files.append(file_name) object_files.append(file_name + '.o')
print('Linking Tensorflow module...') print('Linking Tensorflow module...')
module = link_and_load(object_files, overwrite_destination_file=False, link_cudart=cuda_sources or False) module = link_and_load(object_files, overwrite_destination_file=False)
if module: if module:
print('Loaded Tensorflow module') print('Loaded Tensorflow module.')
return module return module
...@@ -5,11 +5,11 @@ ...@@ -5,11 +5,11 @@
import os import os
import subprocess import subprocess
import tempfile
from os.path import dirname, isfile, join from os.path import dirname, isfile, join
import pytest import pytest
import sympy import sympy
import tempfile
import pystencils import pystencils
from pystencils_autodiff import create_backward_assignments from pystencils_autodiff import create_backward_assignments
...@@ -22,6 +22,7 @@ pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--v']) != 0, ...@@ -22,6 +22,7 @@ pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--v']) != 0,
PROJECT_ROOT = dirname PROJECT_ROOT = dirname
@pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS") @pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
def test_torch_jit(): def test_torch_jit():
""" """
......
...@@ -39,8 +39,36 @@ def test_tensorflow_jit_cpu(): ...@@ -39,8 +39,36 @@ def test_tensorflow_jit_cpu():
backward_ast = pystencils.create_kernel(backward_assignments, target) backward_ast = pystencils.create_kernel(backward_assignments, target)
backward_ast.function_name = 'backward' backward_ast.function_name = 'backward'
module = TensorflowModule(module_name, [forward_ast, backward_ast]) module = TensorflowModule(module_name, [forward_ast, backward_ast])
print(module)
lib = pystencils_autodiff.tensorflow_jit.compile_sources_and_load([str(module)]) lib = pystencils_autodiff.tensorflow_jit.compile_sources_and_load([str(module)])
assert 'call_forward' in dir(lib) assert 'call_forward' in dir(lib)
assert 'call_backward' in dir(lib) assert 'call_backward' in dir(lib)
def test_tensorflow_jit_gpu():
    """GPU variant of the tensorflow JIT test: build a forward kernel and its
    autodiff backward kernel, wrap both in a TensorflowModule, compile/load the
    generated source as a CUDA module, and check both ops are exposed.

    NOTE(review): requires a CUDA-capable build environment — presumably why the
    CPU twin above skips on TRAVIS; confirm whether this one needs a skip marker too.
    """
    # Skip the whole test when tensorflow is not importable.
    pytest.importorskip('tensorflow')
    module_name = "Ololol"
    target = 'gpu'
    # Three 2D fields of shape (20, 40); z is the output, x and y are inputs.
    z, y, x = pystencils.fields("z, y, x: [20,40]")
    a = sympy.Symbol('a')
    # Forward stencil: elementwise z = x * log(a * x * y).
    forward_assignments = pystencils.AssignmentCollection({
        z[0, 0]: x[0, 0] * sympy.log(a * x[0, 0] * y[0, 0])
    })
    # Derive the adjoint (gradient) assignments automatically.
    backward_assignments = create_backward_assignments(forward_assignments)
    forward_ast = pystencils.create_kernel(forward_assignments, target)
    forward_ast.function_name = 'forward'
    backward_ast = pystencils.create_kernel(backward_assignments, target)
    backward_ast.function_name = 'backward'
    module = TensorflowModule(module_name, [forward_ast, backward_ast])
    # First argument is host sources (none here); the generated module source is
    # passed as a CUDA source so it goes through the nvcc path of compile_file.
    lib = pystencils_autodiff.tensorflow_jit.compile_sources_and_load([], [str(module)])
    # Both generated ops must be registered on the loaded library object.
    assert 'call_forward' in dir(lib)
    assert 'call_backward' in dir(lib)
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment