diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 1dc2594e6fd1cd2e359ee49af163d34aca23dcfd..35af9bef1c71ab27c88167487a07f03a3c400cb1 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -13,6 +13,7 @@ from collections.abc import Iterable
 from os.path import dirname, exists, join
 
 from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
+from pystencils.cache import cache_dir
 from pystencils.cpu.cpujit import get_cache_config
 from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
 from pystencils_autodiff._file_io import read_template_from_file, write_file
@@ -106,9 +107,10 @@ class TorchModule(JinjaCppFile):
         torch_extension = load(hash,
                                [file_name],
                                with_cuda=self.is_cuda,
-                               extra_cflags='--std=c++14',
-                               extra_include_paths=[
-                                   get_pycuda_include_path(), get_pystencils_include_path()])
+                               extra_cflags=['--std=c++14'],
+                               build_directory=join(cache_dir, 'object_cache'),
+                               extra_include_paths=[get_pycuda_include_path(),
+                                                    get_pystencils_include_path()])
         return torch_extension