Skip to content
Snippets Groups Projects
Commit 3bcda22b authored by Stephan Seitz
Browse files

Enable additional_compile_flags, additional_link_flags for tensorflow_jit

parent 0cccbe5e
Branches
Tags
No related merge requests found
Pipeline #17930 failed
......@@ -42,7 +42,7 @@ except ImportError:
pass
def link_and_load(object_files, destination_file=None, overwrite_destination_file=True):
def link_and_load(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]):
"""Compiles given :param:`source_file` to a Tensorflow shared Library.
.. warning::
......@@ -64,6 +64,7 @@ def link_and_load(object_files, destination_file=None, overwrite_destination_fil
*object_files,
*_tf_link_flags,
*_include_flags,
*additional_link_flags,
_shared_object_flag,
_output_flag,
destination_file] # /out: for msvc???
......@@ -97,7 +98,7 @@ if pystencils.gpucuda.cudajit.USE_FAST_MATH:
_nvcc_flags.append('-use_fast_math')
def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True):
def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True, additional_compile_flags=[]):
if 'tensorflow_host_compiler' not in get_compiler_config():
get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command']
write_file(pystencils.cpu.cpujit.get_configuration_file_path(), json.dumps(pystencils.cpu.cpujit._config))
......@@ -119,6 +120,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
_do_not_link_flag,
*_tf_compile_flags,
*_include_flags,
*additional_compile_flags,
_output_flag,
destination_file]
else:
......@@ -128,6 +130,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
_do_not_link_flag,
*_tf_compile_flags,
*_include_flags,
*additional_compile_flags,
_output_flag,
destination_file]
......@@ -136,7 +139,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
return destination_file
def compile_sources_and_load(host_sources, cuda_sources=[]):
def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_flags=[], additional_link_flags=[]):
object_files = []
......@@ -152,11 +155,12 @@ def compile_sources_and_load(host_sources, cuda_sources=[]):
file_name = join(pystencils.cache.cache_dir, f'{abs(hash(source_code)):x}{file_extension}')
write_file(file_name, source_code)
compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False)
compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False,
additional_compile_flags=additional_compile_flags)
object_files.append(file_name + '.o')
print('Linking Tensorflow module...')
module = link_and_load(object_files, overwrite_destination_file=False)
module = link_and_load(object_files, overwrite_destination_file=False, additional_link_flags=additional_link_flags)
if module:
print('Loaded Tensorflow module.')
return module
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment