From afd19c8b5b9c344f5188576f4e4c100cb9bf3034 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 14 Oct 2019 17:23:23 +0200
Subject: [PATCH] Add pycuda/pystencils include paths to pytorch native
 compilation

---
 src/pystencils_autodiff/backends/astnodes.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 278b05c..0e5bdd4 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -102,7 +102,11 @@ class TorchModule(JinjaCppFile):
         if not exists(file_name):
             write_file(file_name, source_code)
         # TODO: propagate extra headers
-        torch_extension = load(hash, [file_name], with_cuda=self.is_cuda)
+        torch_extension = load(hash,
+                               [file_name],
+                               with_cuda=self.is_cuda,
+                               extra_include_paths=[
+                                   get_pycuda_include_path(), get_pystencils_include_path()])
         return torch_extension
 
 
-- 
GitLab