diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index 9e58073a09b862dd81ac8acaee67120eb5410d95..97f10e020c051f9bd2b9bdf4d4b279fc15584baa 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -139,7 +139,7 @@ if pystencils.gpucuda.cudajit.USE_FAST_MATH: def compile_file(file, use_nvcc=False, - nvcc='nvcc', + nvcc=None, overwrite_destination_file=True, additional_compile_flags=[], openmp=True): @@ -148,7 +148,7 @@ def compile_file(file, import tensorflow as tf if use_nvcc: - command_prefix = [NVCC_BINARY, + command_prefix = [nvcc or NVCC_BINARY, '--expt-relaxed-constexpr', '-ccbin', get_compiler_config()['tensorflow_host_compiler'],