diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 278b05c803f95c538bb6f94b0ea729ee1e91413a..0e5bdd4c9155a6734a141a3ef0eaef0782622449 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -102,7 +102,11 @@ class TorchModule(JinjaCppFile): if not exists(file_name): write_file(file_name, source_code) # TODO: propagate extra headers - torch_extension = load(hash, [file_name], with_cuda=self.is_cuda) + torch_extension = load(hash, + [file_name], + with_cuda=self.is_cuda, + extra_include_paths=[ + get_pycuda_include_path(), get_pystencils_include_path()]) return torch_extension