From 245bf3008f31f58bd6fcb818230e20a2c66fc137 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 24 Feb 2020 16:19:09 +0100
Subject: [PATCH] Add build_dir option to TorchModule.compile

---
 src/pystencils_autodiff/backends/astnodes.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index db75305..e35d145 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -137,12 +137,13 @@ class TorchModule(JinjaCppFile):
         return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
                                function_name='call_' + kernel_ast.function_name)
 
-    def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None):
+    def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None, build_dir=None):
         from torch.utils.cpp_extension import load
         file_extension = '.cu' if self.is_cuda else '.cpp'
         source_code = str(self)
         hash = _hash(source_code.encode()).hexdigest()
-        build_dir = join(get_cache_config()['object_cache'], self.module_name)
+        if not build_dir:
+            build_dir = join(get_cache_config()['object_cache'], self.module_name)
         os.makedirs(build_dir, exist_ok=True)
         file_name = join(build_dir, f'{hash}{file_extension}')
 
-- 
GitLab