Skip to content
Snippets Groups Projects
Commit 9592e961 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fix caching for PybindModules

parent 70aee528
Branches
Tags
No related merge requests found
Pipeline #19508 failed
......@@ -94,8 +94,9 @@ class TorchModule(JinjaCppFile):
super().__init__(ast_dict)
def generate_wrapper_function(self, kernel_ast):
return WrapperFunction(self.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
@classmethod
def generate_wrapper_function(cls, kernel_ast):
return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
function_name='call_' + kernel_ast.function_name)
def compile(self):
......@@ -165,10 +166,12 @@ setup_pybind11(cfg)
assert not self.is_cuda
cache_dir = get_cache_config()['object_cache']
source_code = self.CPP_IMPORT_PREFIX + str(self)
file_name = join(cache_dir, f'{self.module_name}.cpp')
hash_str = _hash(source_code.encode()).hexdigest()
cache_dir = join(get_cache_config()['object_cache'], f'cppimport_{hash_str}')
file_name = join(cache_dir, f'{self.module_name}.cpp')
os.makedirs(cache_dir, exist_ok=True)
if not exists(file_name):
write_file(file_name, source_code)
# TODO: propagate extra headers
......@@ -176,7 +179,8 @@ setup_pybind11(cfg)
sys.path.append(cache_dir)
try:
torch_extension = cppimport.imp(f'{self.module_name}')
except Exception:
torch_extension = cppimport.imp(self.module_name)
except Exception as e:
print(e)
torch_extension = load(self.module_name, [file_name])
return torch_extension
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment