Skip to content
Snippets Groups Projects
Commit 3bcda22b authored by Stephan Seitz
Browse files

Enable additional_compile_flags, additional_link_flags for tensorflow_jit

parent 0cccbe5e
No related branches found
No related tags found
No related merge requests found
Pipeline #17930 failed
...@@ -42,7 +42,7 @@ except ImportError: ...@@ -42,7 +42,7 @@ except ImportError:
pass pass
def link_and_load(object_files, destination_file=None, overwrite_destination_file=True): def link_and_load(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]):
"""Compiles given :param:`source_file` to a Tensorflow shared Library. """Compiles given :param:`source_file` to a Tensorflow shared Library.
.. warning:: .. warning::
...@@ -64,6 +64,7 @@ def link_and_load(object_files, destination_file=None, overwrite_destination_fil ...@@ -64,6 +64,7 @@ def link_and_load(object_files, destination_file=None, overwrite_destination_fil
*object_files, *object_files,
*_tf_link_flags, *_tf_link_flags,
*_include_flags, *_include_flags,
*additional_link_flags,
_shared_object_flag, _shared_object_flag,
_output_flag, _output_flag,
destination_file] # /out: for msvc??? destination_file] # /out: for msvc???
...@@ -97,7 +98,7 @@ if pystencils.gpucuda.cudajit.USE_FAST_MATH: ...@@ -97,7 +98,7 @@ if pystencils.gpucuda.cudajit.USE_FAST_MATH:
_nvcc_flags.append('-use_fast_math') _nvcc_flags.append('-use_fast_math')
def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True): def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True, additional_compile_flags=[]):
if 'tensorflow_host_compiler' not in get_compiler_config(): if 'tensorflow_host_compiler' not in get_compiler_config():
get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command'] get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command']
write_file(pystencils.cpu.cpujit.get_configuration_file_path(), json.dumps(pystencils.cpu.cpujit._config)) write_file(pystencils.cpu.cpujit.get_configuration_file_path(), json.dumps(pystencils.cpu.cpujit._config))
...@@ -119,6 +120,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T ...@@ -119,6 +120,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
_do_not_link_flag, _do_not_link_flag,
*_tf_compile_flags, *_tf_compile_flags,
*_include_flags, *_include_flags,
*additional_compile_flags,
_output_flag, _output_flag,
destination_file] destination_file]
else: else:
...@@ -128,6 +130,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T ...@@ -128,6 +130,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
_do_not_link_flag, _do_not_link_flag,
*_tf_compile_flags, *_tf_compile_flags,
*_include_flags, *_include_flags,
*additional_compile_flags,
_output_flag, _output_flag,
destination_file] destination_file]
...@@ -136,7 +139,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T ...@@ -136,7 +139,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
return destination_file return destination_file
def compile_sources_and_load(host_sources, cuda_sources=[]): def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_flags=[], additional_link_flags=[]):
object_files = [] object_files = []
...@@ -152,11 +155,12 @@ def compile_sources_and_load(host_sources, cuda_sources=[]): ...@@ -152,11 +155,12 @@ def compile_sources_and_load(host_sources, cuda_sources=[]):
file_name = join(pystencils.cache.cache_dir, f'{abs(hash(source_code)):x}{file_extension}') file_name = join(pystencils.cache.cache_dir, f'{abs(hash(source_code)):x}{file_extension}')
write_file(file_name, source_code) write_file(file_name, source_code)
compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False) compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False,
additional_compile_flags=additional_compile_flags)
object_files.append(file_name + '.o') object_files.append(file_name + '.o')
print('Linking Tensorflow module...') print('Linking Tensorflow module...')
module = link_and_load(object_files, overwrite_destination_file=False) module = link_and_load(object_files, overwrite_destination_file=False, additional_link_flags=additional_link_flags)
if module: if module:
print('Loaded Tensorflow module.') print('Loaded Tensorflow module.')
return module return module
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment