diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 278b05c803f95c538bb6f94b0ea729ee1e91413a..0e5bdd4c9155a6734a141a3ef0eaef0782622449 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -102,7 +102,11 @@ class TorchModule(JinjaCppFile):
         if not exists(file_name):
             write_file(file_name, source_code)
         # TODO: propagate extra headers
-        torch_extension = load(hash, [file_name], with_cuda=self.is_cuda)
+        torch_extension = load(hash,
+                               [file_name],
+                               with_cuda=self.is_cuda,
+                               extra_include_paths=[
+                                   get_pycuda_include_path(), get_pystencils_include_path()])
         return torch_extension