diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 1dc2594e6fd1cd2e359ee49af163d34aca23dcfd..35af9bef1c71ab27c88167487a07f03a3c400cb1 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -13,6 +13,7 @@ from collections.abc import Iterable from os.path import dirname, exists, join from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol +from pystencils.cache import cache_dir from pystencils.cpu.cpujit import get_cache_config from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils_autodiff._file_io import read_template_from_file, write_file @@ -106,9 +107,10 @@ class TorchModule(JinjaCppFile): torch_extension = load(hash, [file_name], with_cuda=self.is_cuda, - extra_cflags='--std=c++14', - extra_include_paths=[ - get_pycuda_include_path(), get_pystencils_include_path()]) + extra_cflags=['--std=c++14'], + build_directory=join(cache_dir, 'object_cache'), + extra_include_paths=[get_pycuda_include_path(), + get_pystencils_include_path()]) return torch_extension