From 8f91bec3ee03b32643089dbb9f08741fe3261dc1 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 29 Nov 2019 17:41:43 +0100 Subject: [PATCH] Add openmp flag to tensorflow jit --- src/pystencils_autodiff/tensorflow_jit.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index 1dd21e5..11c0d15 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -40,6 +40,7 @@ if get_compiler_config()['os'] != 'windows': _compile_env = os.environ.copy() _object_file_extension = '.o' _link_cudart = '-lcudart' + _openmp_flag = '-fopenmp' else: _do_not_link_flag = '-c' _output_flag = '-o' @@ -54,6 +55,7 @@ else: _compile_env.update(config_env) _object_file_extension = '.obj' _link_cudart = '/link cudart' # ??? + _openmp_flag = '/openmp' # ??? def link(object_files, @@ -153,6 +155,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T 'cu', '-Xcompiler', _position_independent_flag, + _openmp_flag, _do_not_link_flag, *tf.sysconfig.get_compile_flags(), *_include_flags, @@ -163,6 +166,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T *(get_compiler_config()['flags']).split(' '), file, _do_not_link_flag, + _openmp_flag, *tf.sysconfig.get_compile_flags(), *_include_flags, *additional_compile_flags, -- GitLab