From afd19c8b5b9c344f5188576f4e4c100cb9bf3034 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 14 Oct 2019 17:23:23 +0200 Subject: [PATCH] Add pycuda/pystencils include paths to pytorch native compilation --- src/pystencils_autodiff/backends/astnodes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 278b05c..0e5bdd4 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -102,7 +102,11 @@ class TorchModule(JinjaCppFile): if not exists(file_name): write_file(file_name, source_code) # TODO: propagate extra headers - torch_extension = load(hash, [file_name], with_cuda=self.is_cuda) + torch_extension = load(hash, + [file_name], + with_cuda=self.is_cuda, + extra_include_paths=[ + get_pycuda_include_path(), get_pystencils_include_path()]) return torch_extension -- GitLab