diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index bec0346100800f772774dc2d2753438ae81b3572..fdb35553d3072de3d96fdc22bf6c9b0220a77ac8 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -50,7 +50,11 @@ else: _link_cudart = '/link cudart' # ??? -def link(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]): +def link(object_files, + destination_file=None, + overwrite_destination_file=True, + additional_link_flags=[], + link_cudart=True): """Compiles given :param:`source_file` to a Tensorflow shared Library. .. warning:: @@ -69,7 +73,6 @@ def link(object_files, destination_file=None, overwrite_destination_file=True, a *tf.sysconfig.get_link_flags(), *_include_flags, *additional_link_flags, - _link_cudart, _shared_object_flag, _output_flag] if not destination_file: @@ -78,15 +81,25 @@ def link(object_files, destination_file=None, overwrite_destination_file=True, a if not exists(destination_file) or overwrite_destination_file: command = command_prefix + [destination_file] + if link_cudart: + command.append(_link_cudart) subprocess.check_call(command, env=_compile_env) return destination_file -def link_and_load(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]): +def link_and_load(object_files, + destination_file=None, + overwrite_destination_file=True, + additional_link_flags=[], + link_cudart=True): import tensorflow as tf - destination_file = link(object_files, destination_file, overwrite_destination_file, additional_link_flags) + destination_file = link(object_files, + destination_file, + overwrite_destination_file, + additional_link_flags, + link_cudart) lib = tf.load_op_library(destination_file) return lib @@ -197,7 +210,8 @@ def compile_sources_and_load(host_sources, print('Linking Tensorflow module...') module_file = link(object_files, overwrite_destination_file=False, - additional_link_flags=additional_link_flags) + additional_link_flags=additional_link_flags, + link_cudart=cuda_sources) if not compile_only: module = tf.load_op_library(module_file) if module: