diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 1d3d6a2b46ef9eb2e5c78448056f5c36884f8d10..bd6c78131af9833e6deff4c645130aadd80d152c 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -8,16 +8,19 @@
 
 """
 
+import os
 from collections.abc import Iterable
-from os.path import dirname, join
+from os.path import dirname, exists, join
 
+import pystencils
 from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
-from pystencils_autodiff._file_io import read_template_from_file
+from pystencils_autodiff._file_io import read_template_from_file, write_file
 from pystencils_autodiff.backends.python_bindings import (
     PybindFunctionWrapping, PybindPythonBindings, TensorflowFunctionWrapping,
     TensorflowPythonBindings, TorchPythonBindings)
 from pystencils_autodiff.framework_integration.astnodes import (
     DestructuringBindingsForFieldClass, JinjaCppFile, WrapperFunction, generate_kernel_call)
+from pystencils_autodiff.tensorflow_jit import _hash
 
 
 class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
@@ -73,6 +76,7 @@ class TorchModule(JinjaCppFile):
         if not isinstance(kernel_asts, Iterable):
             kernel_asts = [kernel_asts]
         wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts]
+        self.module_name = module_name
 
         ast_dict = {
             'kernels': kernel_asts,
@@ -88,6 +92,25 @@
         return WrapperFunction(self.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
                                function_name='call_' + kernel_ast.function_name)
 
+    def compile(self):
+        # Torch is imported lazily so that it stays an optional dependency.
+        from torch.utils.cpp_extension import load
+
+        file_extension = '.cu' if self.is_cuda else '.cpp'
+        source_code = str(self)
+        # Key the build directory by a digest of the generated source so that
+        # identical modules are built only once.
+        source_hash = _hash(source_code.encode()).hexdigest()
+        build_dir = join(pystencils.cache.cache_dir, source_hash)
+        os.makedirs(build_dir, exist_ok=True)
+        file_name = join(build_dir, f'{source_hash}{file_extension}')
+
+        if not exists(file_name):
+            write_file(file_name, source_code)
+        # JIT-build and import the extension via torch's cpp_extension loader.
+        torch_extension = load(source_hash, [file_name], with_cuda=self.is_cuda)
+        return torch_extension
+
 
 class TensorflowModule(TorchModule):
     DESTRUCTURING_CLASS = TensorflowTensorDestructuring
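
For context, a minimal usage sketch of the new TorchModule.compile() convenience
method. The field/kernel setup below is hypothetical and only illustrates the call
sequence; it assumes pystencils and torch are installed and follows the pattern of
the existing tests.

    import pystencils
    from pystencils_autodiff.backends.astnodes import TorchModule

    # Hypothetical single-kernel module computing b = a + 1 on 2D fields.
    a, b = pystencils.fields('a, b: float64[2D]')
    forward_ast = pystencils.create_kernel([pystencils.Assignment(b[0, 0], a[0, 0] + 1)])
    forward_ast.function_name = 'forward'

    module = TorchModule('demo_module', [forward_ast])
    extension = module.compile()  # caches the source, then JIT-builds the extension
    assert 'call_forward' in dir(extension)
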
diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py
index 4e88075a585bae7c619c3407f4068d665a30ac92..40a5e23f16c817f3314865ea9a48f453b2ec6a5c 100644
--- a/tests/backends/test_torch_native_compilation.py
+++ b/tests/backends/test_torch_native_compilation.py
@@ -73,6 +73,12 @@ def test_torch_native_compilation_cpu():
     assert 'call_forward' in dir(torch_extension)
     assert 'call_backward' in dir(torch_extension)
 
+    # The compile() convenience method should yield an equivalent extension.
+    torch_extension = module.compile()
+    assert torch_extension is not None
+    assert 'call_forward' in dir(torch_extension)
+    assert 'call_backward' in dir(torch_extension)
+
 
 @pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
 def test_torch_native_compilation_gpu():
@@ -106,6 +112,12 @@ def test_torch_native_compilation_gpu():
     assert 'call_forward' in dir(torch_extension)
     assert 'call_backward' in dir(torch_extension)
 
+    # The compile() convenience method should yield an equivalent extension.
+    torch_extension = module.compile()
+    assert torch_extension is not None
+    assert 'call_forward' in dir(torch_extension)
+    assert 'call_backward' in dir(torch_extension)
+
 
 @pytest.mark.skipif(True or 'NO_GPU_EXECUTION' in os.environ, reason='Skip GPU execution tests')
 def test_execute_torch_gpu():
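
The caching in compile() is content-addressed: the generated source is written to a
directory named after its own digest, so identical modules are built only once. A
standalone sketch of the same pattern (hashlib.sha256 stands in for the _hash helper
from pystencils_autodiff.tensorflow_jit, whose exact algorithm is not shown in this
diff):

    import hashlib
    import os
    from os.path import exists, join

    def cached_source_path(cache_dir, source_code, file_extension):
        # Identical sources map to the same digest, hence the same build directory.
        digest = hashlib.sha256(source_code.encode()).hexdigest()
        build_dir = join(cache_dir, digest)
        os.makedirs(build_dir, exist_ok=True)
        file_name = join(build_dir, digest + file_extension)
        if not exists(file_name):
            with open(file_name, 'w') as f:
                f.write(source_code)
        return file_name
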