Skip to content
Snippets Groups Projects
Commit afd19c8b authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add pycuda/pystencils include paths to pytorch native compilation

parent a31c13d8
No related branches found
No related tags found
No related merge requests found
......@@ -102,7 +102,11 @@ class TorchModule(JinjaCppFile):
if not exists(file_name):
write_file(file_name, source_code)
# TODO: propagate extra headers
torch_extension = load(hash, [file_name], with_cuda=self.is_cuda)
torch_extension = load(hash,
[file_name],
with_cuda=self.is_cuda,
extra_include_paths=[
get_pycuda_include_path(), get_pystencils_include_path()])
return torch_extension
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment