Skip to content
Snippets Groups Projects
Commit 798d2f80 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

tensorflow_jit: include compile commands into file hash

parent 5fae0e92
No related branches found
No related tags found
No related merge requests found
Pipeline #18029 passed
......@@ -13,10 +13,10 @@ import subprocess
import sysconfig
from os.path import exists, join
from tqdm import tqdm
import pystencils
from pystencils.cpu.cpujit import get_cache_config, get_compiler_config, get_pystencils_include_path
from tqdm import tqdm
from pystencils_autodiff._file_io import read_file, write_file
_hash = hashlib.md5
......@@ -65,22 +65,20 @@ def link(object_files, destination_file=None, overwrite_destination_file=True, a
:returns: Object containing all Tensorflow Ops in that shared library.
"""
if not destination_file:
destination_file = join(get_cache_config()['object_cache'],
f"{_hash('.'.join(sorted(object_files)).encode()).hexdigest()}.so")
if not exists(destination_file) or overwrite_destination_file:
command = [get_compiler_config()['command'],
command_prefix = [get_compiler_config()['command'],
_position_independent_flag,
*object_files,
*_tf_link_flags,
*_include_flags,
*additional_link_flags,
_shared_object_flag,
_output_flag,
destination_file] # /out: for msvc???
_output_flag]
if not destination_file:
destination_file = join(get_cache_config()['object_cache'],
f"{_hash('.'.join(sorted(object_files + command_prefix)).encode()).hexdigest()}.so")
if not exists(destination_file) or overwrite_destination_file:
command = command_prefix + [destination_file]
subprocess.check_call(command, env=_compile_env)
return destination_file
......@@ -121,9 +119,8 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
if 'tensorflow_host_compiler' not in get_compiler_config():
get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command']
destination_file = file + _object_file_extension
if use_nvcc:
command = [nvcc,
command_prefix = [nvcc,
'--expt-relaxed-constexpr',
'-ccbin',
get_compiler_config()['tensorflow_host_compiler'],
......@@ -139,20 +136,21 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
*_tf_compile_flags,
*_include_flags,
*additional_compile_flags,
_output_flag,
destination_file]
_output_flag]
else:
command = [get_compiler_config()['command'],
command_prefix = [get_compiler_config()['command'],
*(get_compiler_config()['flags']).split(' '),
file,
_do_not_link_flag,
*_tf_compile_flags,
*_include_flags,
*additional_compile_flags,
_output_flag,
destination_file]
_output_flag]
destination_file = f'{file}_{_hash(".".join(command_prefix).encode()).hexdigest()}.{_object_file_extension}'
if not exists(destination_file) or overwrite_destination_file:
command = command_prefix + [destination_file]
subprocess.check_call(command, env=_compile_env)
return destination_file
......@@ -209,11 +207,11 @@ def compile_sources_and_load(host_sources,
file_name = join(pystencils.cache.cache_dir, f'{_hash(source_code.encode()).hexdigest()}{file_extension}')
write_file(file_name, source_code)
compile_file(file_name,
object_file = compile_file(file_name,
use_nvcc=is_cuda,
overwrite_destination_file=False,
additional_compile_flags=additional_compile_flags)
object_files.append(file_name + _object_file_extension)
object_files.append(object_file)
print('Linking Tensorflow module...')
module_file = link(object_files,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment