From 245bf3008f31f58bd6fcb818230e20a2c66fc137 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 24 Feb 2020 16:19:09 +0100 Subject: [PATCH] Add build_dir option to TorchModule.compile --- src/pystencils_autodiff/backends/astnodes.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index db75305..e35d145 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -137,12 +137,13 @@ class TorchModule(JinjaCppFile): return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)), function_name='call_' + kernel_ast.function_name) - def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None): + def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None, build_dir=None): from torch.utils.cpp_extension import load file_extension = '.cu' if self.is_cuda else '.cpp' source_code = str(self) hash = _hash(source_code.encode()).hexdigest() - build_dir = join(get_cache_config()['object_cache'], self.module_name) + if not build_dir: + build_dir = join(get_cache_config()['object_cache'], self.module_name) os.makedirs(build_dir, exist_ok=True) file_name = join(build_dir, f'{hash}{file_extension}') -- GitLab