diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py
index bec0346100800f772774dc2d2753438ae81b3572..fdb35553d3072de3d96fdc22bf6c9b0220a77ac8 100644
--- a/src/pystencils_autodiff/tensorflow_jit.py
+++ b/src/pystencils_autodiff/tensorflow_jit.py
@@ -50,7 +50,11 @@ else:
     _link_cudart = '/link cudart'  # ???
 
 
-def link(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]):
+def link(object_files,
+         destination_file=None,
+         overwrite_destination_file=True,
+         additional_link_flags=[],
+         link_cudart=True):
     """Compiles given :param:`source_file` to a Tensorflow shared Library.
 
     .. warning::
@@ -69,7 +73,6 @@ def link(object_files, destination_file=None, overwrite_destination_file=True, a
                       *tf.sysconfig.get_link_flags(),
                       *_include_flags,
                       *additional_link_flags,
-                      _link_cudart,
                       _shared_object_flag,
                       _output_flag]
     if not destination_file:
@@ -78,15 +81,25 @@ def link(object_files, destination_file=None, overwrite_destination_file=True, a
 
     if not exists(destination_file) or overwrite_destination_file:
         command = command_prefix + [destination_file]
+        if link_cudart:
+            command.append(_link_cudart)
         subprocess.check_call(command, env=_compile_env)
 
     return destination_file
 
 
-def link_and_load(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]):
+def link_and_load(object_files,
+                  destination_file=None,
+                  overwrite_destination_file=True,
+                  additional_link_flags=[],
+                  link_cudart=True):
     import tensorflow as tf
 
-    destination_file = link(object_files, destination_file, overwrite_destination_file, additional_link_flags)
+    destination_file = link(object_files,
+                            destination_file,
+                            overwrite_destination_file,
+                            additional_link_flags,
+                            link_cudart)
     lib = tf.load_op_library(destination_file)
     return lib
 
@@ -197,7 +210,8 @@ def compile_sources_and_load(host_sources,
     print('Linking Tensorflow module...')
     module_file = link(object_files,
                        overwrite_destination_file=False,
-                       additional_link_flags=additional_link_flags)
+                       additional_link_flags=additional_link_flags,
+                       link_cudart=cuda_sources)
     if not compile_only:
         module = tf.load_op_library(module_file)
         if module: