Commit 02bbb215 authored by Stephan Seitz

Give each torch operation its own build_directory

parent f4ae2f58
@@ -8,6 +8,7 @@
 """
+import os
 import sys
 from collections.abc import Iterable
 from os.path import dirname, exists, join
@@ -102,12 +103,15 @@ class TorchModule(JinjaCppFile):
         if not exists(file_name):
             write_file(file_name, source_code)
+        # TODO: propagate extra headers
+        build_dir = join(get_cache_config()['object_cache'], self.module_name)
+        os.makedirs(build_dir, exist_ok=True)
         torch_extension = load(hash,
                                [file_name],
                                with_cuda=self.is_cuda,
                                extra_cflags=['--std=c++14'],
-                               build_directory=get_cache_config()['object_cache'],
+                               build_directory=build_dir,
                                extra_include_paths=[get_pycuda_include_path(),
                                                     get_pystencils_include_path()])
         return torch_extension
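For context, the change gives each generated TorchModule its own subdirectory below the object cache and passes it as build_directory to torch.utils.cpp_extension.load, so separately generated operations no longer share a single build directory. Below is a minimal sketch of the same pattern; the cache path, module name, and source file are hypothetical stand-ins for get_cache_config()['object_cache'], self.module_name, and the generated source, and the sketch assumes the C++ source file already exists.

import os
from os.path import join

from torch.utils.cpp_extension import load

# Hypothetical stand-ins for the values used in the commit.
object_cache = '/tmp/pystencils_object_cache'   # get_cache_config()['object_cache']
module_name = 'my_torch_op'                     # self.module_name
source_file = join(object_cache, module_name + '.cpp')  # assumed to be written beforehand

# One build directory per operation, so extensions built for different
# operations do not clobber each other's ninja files and shared objects.
build_dir = join(object_cache, module_name)
os.makedirs(build_dir, exist_ok=True)

extension = load(module_name,
                 [source_file],
                 extra_cflags=['--std=c++14'],
                 build_directory=build_dir)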