diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index 11c0d15394bae88ca39ae5589c7fe1495b6cab9a..9e58073a09b862dd81ac8acaee67120eb5410d95 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -137,7 +137,12 @@ if pystencils.gpucuda.cudajit.USE_FAST_MATH: _nvcc_flags.append('-use_fast_math') -def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=True, additional_compile_flags=[]): +def compile_file(file, + use_nvcc=False, + nvcc='nvcc', + overwrite_destination_file=True, + additional_compile_flags=[], + openmp=True): if 'tensorflow_host_compiler' not in get_compiler_config(): get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command'] import tensorflow as tf @@ -155,7 +160,6 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T 'cu', '-Xcompiler', _position_independent_flag, - _openmp_flag, _do_not_link_flag, *tf.sysconfig.get_compile_flags(), *_include_flags, @@ -166,13 +170,14 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T *(get_compiler_config()['flags']).split(' '), file, _do_not_link_flag, - _openmp_flag, *tf.sysconfig.get_compile_flags(), *_include_flags, *additional_compile_flags, _output_flag] destination_file = f'{file}_{_hash(".".join(command_prefix).encode()).hexdigest()}{_object_file_extension}' + if openmp: + command_prefix.append(_output_flag) if not exists(destination_file) or overwrite_destination_file: command = command_prefix + [destination_file]