Skip to content
Snippets Groups Projects
Commit 9ed3556c authored by Stephan Seitz
Browse files

Make openmp for tf cpu optional

parent 8f91bec3
No related branches found
No related tags found
No related merge requests found
...@@ -137,7 +137,12 @@ if pystencils.gpucuda.cudajit.USE_FAST_MATH: ...@@ -137,7 +137,12 @@ if pystencils.gpucuda.cudajit.USE_FAST_MATH:
_nvcc_flags.append('-use_fast_math') _nvcc_flags.append('-use_fast_math')
def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True, additional_compile_flags=[]): def compile_file(file,
use_nvcc=False,
nvcc='nvcc',
overwrite_destination_file=True,
additional_compile_flags=[],
openmp=True):
if 'tensorflow_host_compiler' not in get_compiler_config(): if 'tensorflow_host_compiler' not in get_compiler_config():
get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command'] get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command']
import tensorflow as tf import tensorflow as tf
...@@ -155,7 +160,6 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T ...@@ -155,7 +160,6 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
'cu', 'cu',
'-Xcompiler', '-Xcompiler',
_position_independent_flag, _position_independent_flag,
_openmp_flag,
_do_not_link_flag, _do_not_link_flag,
*tf.sysconfig.get_compile_flags(), *tf.sysconfig.get_compile_flags(),
*_include_flags, *_include_flags,
...@@ -166,13 +170,14 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T ...@@ -166,13 +170,14 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
*(get_compiler_config()['flags']).split(' '), *(get_compiler_config()['flags']).split(' '),
file, file,
_do_not_link_flag, _do_not_link_flag,
_openmp_flag,
*tf.sysconfig.get_compile_flags(), *tf.sysconfig.get_compile_flags(),
*_include_flags, *_include_flags,
*additional_compile_flags, *additional_compile_flags,
_output_flag] _output_flag]
destination_file = f'{file}_{_hash(".".join(command_prefix).encode()).hexdigest()}{_object_file_extension}' destination_file = f'{file}_{_hash(".".join(command_prefix).encode()).hexdigest()}{_object_file_extension}'
if openmp:
command_prefix.append(_output_flag)
if not exists(destination_file) or overwrite_destination_file: if not exists(destination_file) or overwrite_destination_file:
command = command_prefix + [destination_file] command = command_prefix + [destination_file]
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment