diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index 3a0fcbb7cae2e87836caad9d4b60bf286803966d..ab1a05eb7177757d35412335f861819b1451126b 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -13,10 +13,10 @@ import subprocess import sysconfig from os.path import exists, join -from tqdm import tqdm - import pystencils from pystencils.cpu.cpujit import get_cache_config, get_compiler_config, get_pystencils_include_path +from tqdm import tqdm + from pystencils_autodiff._file_io import read_file, write_file _hash = hashlib.md5 @@ -65,22 +65,20 @@ def link(object_files, destination_file=None, overwrite_destination_file=True, a :returns: Object containing all Tensorflow Ops in that shared library. """ - + command_prefix = [get_compiler_config()['command'], + _position_independent_flag, + *object_files, + *_tf_link_flags, + *_include_flags, + *additional_link_flags, + _shared_object_flag, + _output_flag] if not destination_file: destination_file = join(get_cache_config()['object_cache'], - f"{_hash('.'.join(sorted(object_files)).encode()).hexdigest()}.so") + f"{_hash('.'.join(sorted(object_files + command_prefix)).encode()).hexdigest()}.so") if not exists(destination_file) or overwrite_destination_file: - command = [get_compiler_config()['command'], - _position_independent_flag, - *object_files, - *_tf_link_flags, - *_include_flags, - *additional_link_flags, - _shared_object_flag, - _output_flag, - destination_file] # /out: for msvc??? - + command = command_prefix + [destination_file] subprocess.check_call(command, env=_compile_env) return destination_file @@ -121,38 +119,38 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T if 'tensorflow_host_compiler' not in get_compiler_config(): get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command'] - destination_file = file + _object_file_extension if use_nvcc: - command = [nvcc, - '--expt-relaxed-constexpr', - '-ccbin', - get_compiler_config()['tensorflow_host_compiler'], - '-Xcompiler', - get_compiler_config()['flags'].replace('c++11', 'c++14'), - *_nvcc_flags, - file, - '-x', - 'cu', - '-Xcompiler', - _position_independent_flag, - _do_not_link_flag, - *_tf_compile_flags, - *_include_flags, - *additional_compile_flags, - _output_flag, - destination_file] + command_prefix = [nvcc, + '--expt-relaxed-constexpr', + '-ccbin', + get_compiler_config()['tensorflow_host_compiler'], + '-Xcompiler', + get_compiler_config()['flags'].replace('c++11', 'c++14'), + *_nvcc_flags, + file, + '-x', + 'cu', + '-Xcompiler', + _position_independent_flag, + _do_not_link_flag, + *_tf_compile_flags, + *_include_flags, + *additional_compile_flags, + _output_flag] else: - command = [get_compiler_config()['command'], - *(get_compiler_config()['flags']).split(' '), - file, - _do_not_link_flag, - *_tf_compile_flags, - *_include_flags, - *additional_compile_flags, - _output_flag, - destination_file] + command_prefix = [get_compiler_config()['command'], + *(get_compiler_config()['flags']).split(' '), + file, + _do_not_link_flag, + *_tf_compile_flags, + *_include_flags, + *additional_compile_flags, + _output_flag] + + destination_file = f'{file}_{_hash(".".join(command_prefix).encode()).hexdigest()}.{_object_file_extension}' if not exists(destination_file) or overwrite_destination_file: + command = command_prefix + [destination_file] subprocess.check_call(command, env=_compile_env) return destination_file @@ -209,11 +207,11 @@ def compile_sources_and_load(host_sources, file_name = join(pystencils.cache.cache_dir, f'{_hash(source_code.encode()).hexdigest()}{file_extension}') write_file(file_name, source_code) - compile_file(file_name, - use_nvcc=is_cuda, - overwrite_destination_file=False, - additional_compile_flags=additional_compile_flags) - object_files.append(file_name + _object_file_extension) + object_file = compile_file(file_name, + use_nvcc=is_cuda, + overwrite_destination_file=False, + additional_compile_flags=additional_compile_flags) + object_files.append(object_file) print('Linking Tensorflow module...') module_file = link(object_files,