diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index be3fbf98ebe2c5eea30f850a59415e8d1f6c1b07..d0af96eaec2c022b32688ef31e44eb061608ab37 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -42,7 +42,7 @@ except ImportError: pass -def link_and_load(object_files, destination_file=None, overwrite_destination_file=True): +def link_and_load(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]): """Compiles given :param:`source_file` to a Tensorflow shared Library. .. warning:: @@ -64,6 +64,7 @@ def link_and_load(object_files, destination_file=None, overwrite_destination_fil *object_files, *_tf_link_flags, *_include_flags, + *additional_link_flags, _shared_object_flag, _output_flag, destination_file] # /out: for msvc??? @@ -97,7 +98,7 @@ if pystencils.gpucuda.cudajit.USE_FAST_MATH: _nvcc_flags.append('-use_fast_math') -def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True): +def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True, additional_compile_flags=[]): if 'tensorflow_host_compiler' not in get_compiler_config(): get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command'] write_file(pystencils.cpu.cpujit.get_configuration_file_path(), json.dumps(pystencils.cpu.cpujit._config)) @@ -119,6 +120,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T _do_not_link_flag, *_tf_compile_flags, *_include_flags, + *additional_compile_flags, _output_flag, destination_file] else: @@ -128,6 +130,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T _do_not_link_flag, *_tf_compile_flags, *_include_flags, + *additional_compile_flags, _output_flag, destination_file] @@ -136,7 +139,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T return destination_file -def compile_sources_and_load(host_sources, cuda_sources=[]): +def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_flags=[], additional_link_flags=[]): object_files = [] @@ -152,11 +155,12 @@ def compile_sources_and_load(host_sources, cuda_sources=[]): file_name = join(pystencils.cache.cache_dir, f'{abs(hash(source_code)):x}{file_extension}') write_file(file_name, source_code) - compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False) + compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False, + additional_compile_flags=additional_compile_flags) object_files.append(file_name + '.o') print('Linking Tensorflow module...') - module = link_and_load(object_files, overwrite_destination_file=False) + module = link_and_load(object_files, overwrite_destination_file=False, additional_link_flags=additional_link_flags) if module: print('Loaded Tensorflow module.') return module