diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index 909b980f43821ed70ccc21193a9210d70b5c4f4d..3b90107e145984eb8d1dbeab2aeca9706ae92080 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -14,13 +14,17 @@ import sysconfig from os.path import exists, join import p_tqdm - import pystencils from pystencils.cpu.cpujit import get_cache_config, get_compiler_config, get_pystencils_include_path + from pystencils_autodiff._file_io import read_file, write_file _hash = hashlib.md5 +if 'NVCC_BINARY' in os.environ: + NVCC_BINARY = os.environ['NVCC_BINARY'] +else: + NVCC_BINARY = 'nvcc' # TODO: msvc if get_compiler_config()['os'] != 'windows': @@ -120,7 +124,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T get_compiler_config()['tensorflow_host_compiler'] = get_compiler_config()['command'] if use_nvcc: - command_prefix = [nvcc, + command_prefix = [NVCC_BINARY, '--expt-relaxed-constexpr', '-ccbin', get_compiler_config()['tensorflow_host_compiler'],