Skip to content
Snippets Groups Projects
Commit 245bf300 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add build_dir option to TorchModule.compile

parent 26df2dfb
Branches
Tags
No related merge requests found
......@@ -137,12 +137,13 @@ class TorchModule(JinjaCppFile):
return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
function_name='call_' + kernel_ast.function_name)
def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None):
def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None, build_dir=None):
from torch.utils.cpp_extension import load
file_extension = '.cu' if self.is_cuda else '.cpp'
source_code = str(self)
hash = _hash(source_code.encode()).hexdigest()
build_dir = join(get_cache_config()['object_cache'], self.module_name)
if not build_dir:
build_dir = join(get_cache_config()['object_cache'], self.module_name)
os.makedirs(build_dir, exist_ok=True)
file_name = join(build_dir, f'{hash}{file_extension}')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment