diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index db753059432a78a0aaff49b14d3bc96cddd95dfe..e35d1454d537eeb129be7a57f2ab02c952d95cfc 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -137,12 +137,13 @@ class TorchModule(JinjaCppFile): return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)), function_name='call_' + kernel_ast.function_name) - def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None): + def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None, build_dir=None): from torch.utils.cpp_extension import load file_extension = '.cu' if self.is_cuda else '.cpp' source_code = str(self) hash = _hash(source_code.encode()).hexdigest() - build_dir = join(get_cache_config()['object_cache'], self.module_name) + if not build_dir: + build_dir = join(get_cache_config()['object_cache'], self.module_name) os.makedirs(build_dir, exist_ok=True) file_name = join(build_dir, f'{hash}{file_extension}')