Skip to content
Snippets Groups Projects
Commit 7c04a372 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Use compile_env to get right environment for msvc

parent 6c0f5cda
Branches
Tags
No related merge requests found
Pipeline #17978 failed
...@@ -29,6 +29,7 @@ if get_compiler_config()['os'] != 'windows': ...@@ -29,6 +29,7 @@ if get_compiler_config()['os'] != 'windows':
_include_flags = ['-I' + sysconfig.get_paths()['include'], '-I' + get_pystencils_include_path()] _include_flags = ['-I' + sysconfig.get_paths()['include'], '-I' + get_pystencils_include_path()]
_do_not_link_flag = "-c" _do_not_link_flag = "-c"
_position_independent_flag = "-fPIC" _position_independent_flag = "-fPIC"
_compile_environment = os.environ.copy()
else: else:
_do_not_link_flag = "/c" _do_not_link_flag = "/c"
_output_flag = '/OUT:' _output_flag = '/OUT:'
...@@ -36,6 +37,9 @@ else: ...@@ -36,6 +37,9 @@ else:
_include_flags = ['/I' + sysconfig.get_paths()['include'], '/I' + get_pystencils_include_path()] _include_flags = ['/I' + sysconfig.get_paths()['include'], '/I' + get_pystencils_include_path()]
_position_independent_flag = "/DTHIS_FLAG_DOES_NOTHING" _position_independent_flag = "/DTHIS_FLAG_DOES_NOTHING"
get_compiler_config()['command'] = 'cl.exe' get_compiler_config()['command'] = 'cl.exe'
config_env = get_compiler_config()['env'] if 'env' in get_compiler_config() else {}
_compile_environment = os.environ.copy()
_compile_environment.update(config_env)
try: try:
...@@ -75,7 +79,7 @@ def link(object_files, destination_file=None, overwrite_destination_file=True, a ...@@ -75,7 +79,7 @@ def link(object_files, destination_file=None, overwrite_destination_file=True, a
_output_flag, _output_flag,
destination_file] # /out: for msvc??? destination_file] # /out: for msvc???
subprocess.check_call(command) subprocess.check_call(command, env=_compile_environment)
return destination_file return destination_file
...@@ -147,7 +151,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T ...@@ -147,7 +151,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
destination_file] destination_file]
if not exists(destination_file) or overwrite_destination_file: if not exists(destination_file) or overwrite_destination_file:
subprocess.check_call(command) subprocess.check_call(command, env=_compile_environment)
return destination_file return destination_file
......
...@@ -23,6 +23,8 @@ from pystencils_autodiff.backends.astnodes import TensorflowModule ...@@ -23,6 +23,8 @@ from pystencils_autodiff.backends.astnodes import TensorflowModule
import pystencils import pystencils
from pystencils.cpu.cpujit import get_compiler_config from pystencils.cpu.cpujit import get_compiler_config
from pystencils.include import get_pystencils_include_path from pystencils.include import get_pystencils_include_path
from pystencils_autodiff.tensorflow_jit import _compile_env
def test_detect_cpu_vs_cpu(): def test_detect_cpu_vs_cpu():
...@@ -78,7 +80,7 @@ def test_native_tensorflow_compilation_cpu(): ...@@ -78,7 +80,7 @@ def test_native_tensorflow_compilation_cpu():
command = ['c++', '-fPIC', temp_file.name, '-O2', '-shared', command = ['c++', '-fPIC', temp_file.name, '-O2', '-shared',
'-o', 'foo.so'] + compile_flags + link_flags + extra_flags '-o', 'foo.so'] + compile_flags + link_flags + extra_flags
print(command) print(command)
subprocess.check_call(command) subprocess.check_call(command, env=_compile_env)
lib = tf.load_op_library(join(os.getcwd(), 'foo.so')) lib = tf.load_op_library(join(os.getcwd(), 'foo.so'))
assert 'call_forward' in dir(lib) assert 'call_forward' in dir(lib)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment