diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index 916912a9b775b469ca948322b0ed74296e3e7d24..89984cc3a28a03451906d46ab4cf2b1cf93186bd 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -14,9 +14,9 @@ import sysconfig from os.path import exists, join import p_tqdm + import pystencils from pystencils.cpu.cpujit import get_cache_config, get_compiler_config, get_pystencils_include_path - from pystencils_autodiff._file_io import read_file, write_file _hash = hashlib.md5 @@ -35,6 +35,7 @@ if get_compiler_config()['os'] != 'windows': _position_independent_flag = "-fPIC" _compile_env = os.environ.copy() _object_file_extension = '.o' + _link_cudart = '-lcudart' else: _do_not_link_flag = '-c' _output_flag = '-o' @@ -75,7 +76,7 @@ def link(object_files, destination_file=None, overwrite_destination_file=True, a *_tf_link_flags, *_include_flags, *additional_link_flags, - '-lcudart', + _link_cudart, _shared_object_flag, _output_flag] if not destination_file: