From 7c04a372846b89871e8a934c6f8ccd23643d9d78 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 13 Sep 2019 01:41:00 +0200 Subject: [PATCH] Use compile_env to get right environment for msvc --- src/pystencils_autodiff/tensorflow_jit.py | 8 ++++++-- tests/test_native_tensorflow_compilation.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index f0257d6..56c4772 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -29,6 +29,7 @@ if get_compiler_config()['os'] != 'windows': _include_flags = ['-I' + sysconfig.get_paths()['include'], '-I' + get_pystencils_include_path()] _do_not_link_flag = "-c" _position_independent_flag = "-fPIC" + _compile_environment = os.environ.copy() else: _do_not_link_flag = "/c" _output_flag = '/OUT:' @@ -36,6 +37,9 @@ else: _include_flags = ['/I' + sysconfig.get_paths()['include'], '/I' + get_pystencils_include_path()] _position_independent_flag = "/DTHIS_FLAG_DOES_NOTHING" get_compiler_config()['command'] = 'cl.exe' + config_env = get_compiler_config()['env'] if 'env' in get_compiler_config() else {} + _compile_environment = os.environ.copy() + _compile_environment.update(config_env) try: @@ -75,7 +79,7 @@ def link(object_files, destination_file=None, overwrite_destination_file=True, a _output_flag, destination_file] # /out: for msvc??? - subprocess.check_call(command) + subprocess.check_call(command, env=_compile_environment) return destination_file @@ -147,7 +151,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T destination_file] if not exists(destination_file) or overwrite_destination_file: - subprocess.check_call(command) + subprocess.check_call(command, env=_compile_environment) return destination_file diff --git a/tests/test_native_tensorflow_compilation.py b/tests/test_native_tensorflow_compilation.py index 4cf8425..cc64e11 100644 --- a/tests/test_native_tensorflow_compilation.py +++ b/tests/test_native_tensorflow_compilation.py @@ -23,6 +23,8 @@ from pystencils_autodiff.backends.astnodes import TensorflowModule import pystencils from pystencils.cpu.cpujit import get_compiler_config from pystencils.include import get_pystencils_include_path +from pystencils_autodiff.tensorflow_jit import _compile_env + def test_detect_cpu_vs_cpu(): @@ -78,7 +80,7 @@ def test_native_tensorflow_compilation_cpu(): command = ['c++', '-fPIC', temp_file.name, '-O2', '-shared', '-o', 'foo.so'] + compile_flags + link_flags + extra_flags print(command) - subprocess.check_call(command) + subprocess.check_call(command, env=_compile_env) lib = tf.load_op_library(join(os.getcwd(), 'foo.so')) assert 'call_forward' in dir(lib) -- GitLab