Skip to content
Snippets Groups Projects
Commit 9c9abdb1 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Change torch build folder to pystencils' cache_dir

parent 3fbfa910
No related branches found
No related tags found
No related merge requests found
...@@ -13,6 +13,7 @@ from collections.abc import Iterable ...@@ -13,6 +13,7 @@ from collections.abc import Iterable
from os.path import dirname, exists, join from os.path import dirname, exists, join
from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.cache import cache_dir
from pystencils.cpu.cpujit import get_cache_config from pystencils.cpu.cpujit import get_cache_config
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils_autodiff._file_io import read_template_from_file, write_file from pystencils_autodiff._file_io import read_template_from_file, write_file
...@@ -106,9 +107,10 @@ class TorchModule(JinjaCppFile): ...@@ -106,9 +107,10 @@ class TorchModule(JinjaCppFile):
torch_extension = load(hash, torch_extension = load(hash,
[file_name], [file_name],
with_cuda=self.is_cuda, with_cuda=self.is_cuda,
extra_cflags='--std=c++14', extra_cflags=['--std=c++14'],
extra_include_paths=[ build_directory=join(cache_dir, 'object_cache'),
get_pycuda_include_path(), get_pystencils_include_path()]) extra_include_paths=[get_pycuda_include_path(),
get_pystencils_include_path()])
return torch_extension return torch_extension
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment