Skip to content
Snippets Groups Projects
Commit 7c04a372 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Use compile_env to get right environment for msvc

parent 6c0f5cda
Branches
Tags
No related merge requests found
Pipeline #17978 failed
......@@ -29,6 +29,7 @@ if get_compiler_config()['os'] != 'windows':
_include_flags = ['-I' + sysconfig.get_paths()['include'], '-I' + get_pystencils_include_path()]
_do_not_link_flag = "-c"
_position_independent_flag = "-fPIC"
_compile_environment = os.environ.copy()
else:
_do_not_link_flag = "/c"
_output_flag = '/OUT:'
......@@ -36,6 +37,9 @@ else:
_include_flags = ['/I' + sysconfig.get_paths()['include'], '/I' + get_pystencils_include_path()]
_position_independent_flag = "/DTHIS_FLAG_DOES_NOTHING"
get_compiler_config()['command'] = 'cl.exe'
config_env = get_compiler_config()['env'] if 'env' in get_compiler_config() else {}
_compile_environment = os.environ.copy()
_compile_environment.update(config_env)
try:
......@@ -75,7 +79,7 @@ def link(object_files, destination_file=None, overwrite_destination_file=True, a
_output_flag,
destination_file] # /out: for msvc???
subprocess.check_call(command)
subprocess.check_call(command, env=_compile_environment)
return destination_file
......@@ -147,7 +151,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
destination_file]
if not exists(destination_file) or overwrite_destination_file:
subprocess.check_call(command)
subprocess.check_call(command, env=_compile_environment)
return destination_file
......
......@@ -23,6 +23,8 @@ from pystencils_autodiff.backends.astnodes import TensorflowModule
import pystencils
from pystencils.cpu.cpujit import get_compiler_config
from pystencils.include import get_pystencils_include_path
from pystencils_autodiff.tensorflow_jit import _compile_env
def test_detect_cpu_vs_cpu():
......@@ -78,7 +80,7 @@ def test_native_tensorflow_compilation_cpu():
command = ['c++', '-fPIC', temp_file.name, '-O2', '-shared',
'-o', 'foo.so'] + compile_flags + link_flags + extra_flags
print(command)
subprocess.check_call(command)
subprocess.check_call(command, env=_compile_env)
lib = tf.load_op_library(join(os.getcwd(), 'foo.so'))
assert 'call_forward' in dir(lib)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment