diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index 1dd21e587bc275d84e5d8d86f9d67d7bde80f293..11c0d15394bae88ca39ae5589c7fe1495b6cab9a 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -40,6 +40,7 @@ if get_compiler_config()['os'] != 'windows': _compile_env = os.environ.copy() _object_file_extension = '.o' _link_cudart = '-lcudart' + _openmp_flag = '-fopenmp' else: _do_not_link_flag = '-c' _output_flag = '-o' @@ -54,6 +55,7 @@ else: _compile_env.update(config_env) _object_file_extension = '.obj' _link_cudart = '/link cudart' # ??? + _openmp_flag = '/openmp' # ??? def link(object_files, @@ -153,6 +155,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T 'cu', '-Xcompiler', _position_independent_flag, + _openmp_flag, _do_not_link_flag, *tf.sysconfig.get_compile_flags(), *_include_flags, @@ -163,6 +166,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T *(get_compiler_config()['flags']).split(' '), file, _do_not_link_flag, + _openmp_flag, *tf.sysconfig.get_compile_flags(), *_include_flags, *additional_compile_flags,