diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index db753059432a78a0aaff49b14d3bc96cddd95dfe..e35d1454d537eeb129be7a57f2ab02c952d95cfc 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -137,12 +137,13 @@ class TorchModule(JinjaCppFile):
         return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
                                function_name='call_' + kernel_ast.function_name)
 
-    def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None):
+    def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None, build_dir=None):
         from torch.utils.cpp_extension import load
         file_extension = '.cu' if self.is_cuda else '.cpp'
         source_code = str(self)
         hash = _hash(source_code.encode()).hexdigest()
-        build_dir = join(get_cache_config()['object_cache'], self.module_name)
+        if not build_dir:
+            build_dir = join(get_cache_config()['object_cache'], self.module_name)
         os.makedirs(build_dir, exist_ok=True)
         file_name = join(build_dir, f'{hash}{file_extension}')