diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py
index 6cd95960945d0000f297a8d8d75ef521f6e3e745..1384893b0f9c2c06b5f37142c0f36a02885dccf2 100644
--- a/src/pystencils_autodiff/tensorflow_jit.py
+++ b/src/pystencils_autodiff/tensorflow_jit.py
@@ -8,6 +8,7 @@
 
 """
 import hashlib
+import os
 import subprocess
 import sysconfig
 from os.path import exists, join
@@ -45,7 +46,7 @@ except ImportError:
     pass
 
 
-def link_and_load(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]):
+def link(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]):
     """Compiles given :param:`source_file` to a Tensorflow shared Library.
 
     .. warning::
@@ -75,11 +76,18 @@ def link_and_load(object_files, destination_file=None, overwrite_destination_fil
 
         subprocess.check_call(command)
 
+    return destination_file
+
+
+def link_and_load(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]):
+    destination_file = link(object_files, destination_file, overwrite_destination_file, additional_link_flags)
     lib = tf.load_op_library(destination_file)
     return lib
 
 
 def try_get_cuda_arch_flag():
+    if 'PYSTENCILS_TENSORFLOW_NVCC_ARCH' in os.environ:
+        return "-arch=sm_" + os.environ['PYSTENCILS_TENSORFLOW_NVCC_ARCH']
     try:
         from pycuda.driver import Context
         arch = "sm_%d%d" % Context.get_device().compute_capability()
@@ -143,7 +151,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
     return destination_file
 
 
-def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_flags=[], additional_link_flags=[]):
+def compile_sources(host_sources, cuda_sources=[], additional_compile_flags=[], additional_link_flags=[]):
 
     object_files = []
 
@@ -172,3 +180,42 @@ def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_f
     if module:
         print('Loaded Tensorflow module.')
     return module
+
+
+def compile_sources_and_load(host_sources,
+                             cuda_sources=[],
+                             additional_compile_flags=[],
+                             additional_link_flags=[],
+                             compile_only=False):
+
+    object_files = []
+
+    for source in tqdm(host_sources + cuda_sources, desc='Compiling Tensorflow module'):
+        is_cuda = source in cuda_sources
+
+        if exists(source):
+            source_code = read_file(source)
+        else:
+            source_code = source
+
+        file_extension = '.cu' if is_cuda else '.cpp'
+        file_name = join(pystencils.cache.cache_dir, f'{_hash(source_code.encode()).hexdigest()}{file_extension}')
+        write_file(file_name, source_code)
+
+        compile_file(file_name,
+                     use_nvcc=is_cuda,
+                     overwrite_destination_file=False,
+                     additional_compile_flags=additional_compile_flags)
+        object_files.append(file_name + '.o')
+
+    print('Linking Tensorflow module...')
+    module_file = link(object_files,
+                       overwrite_destination_file=False,
+                       additional_link_flags=additional_link_flags)
+    if not compile_only:
+        module = tf.load_op_library(module_file)
+        if module:
+            print('Loaded Tensorflow module.')
+        return module
+    else:
+        return module_file
diff --git a/tests/test_tensorflow_jit.py b/tests/test_tensorflow_jit.py
index 187aa96a925aa113ed7cec70286aa993836036ea..bea327e7c1cb0b4caefdb55bfe4951977200e308 100644
--- a/tests/test_tensorflow_jit.py
+++ b/tests/test_tensorflow_jit.py
@@ -8,6 +8,8 @@
 
 """
 
+from os.path import exists
+
 import pytest
 import sympy
 
@@ -44,6 +46,10 @@ def test_tensorflow_jit_gpu():
     assert 'call_forward_jit_gpu' in dir(lib)
     assert 'call_backward_jit_gpu' in dir(lib)
 
+    file_name = pystencils_autodiff.tensorflow_jit.compile_sources_and_load([], [str(module)], compile_only=True)
+    print(file_name)
+    assert exists(file_name)
+
 
 def test_tensorflow_jit_cpu():