From 74d93c0e4895c7837fe1b6470a17e9a09532396f Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 3 Jan 2020 13:30:49 +0100
Subject: [PATCH] Use pystencils compiler flags for Pybind11/Torch

---
 src/pystencils_autodiff/backends/astnodes.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index e51eb1b..62eb5e9 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -125,7 +125,7 @@ class TorchModule(JinjaCppFile):
         torch_extension = load(hash,
                                [file_name],
                                with_cuda=self.is_cuda,
-                               extra_cflags=['--std=c++14'],
+                               extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
                                extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
                                build_directory=build_dir,
                                extra_include_paths=[get_pycuda_include_path(),
@@ -198,9 +198,19 @@ setup_pybind11(cfg)
         if cache_dir not in sys.path:
             sys.path.append(cache_dir)
 
+        # Torch regards CXX
+        os.environ['CXX'] = get_compiler_config()['command']
+
         try:
             torch_extension = cppimport.imp(f'cppimport_{hash_str}')
         except Exception as e:
             print(e)
-            torch_extension = load(self.module_name, [file_name])
+            torch_extension = load(hash,
+                               [file_name],
+                               with_cuda=self.is_cuda,
+                               extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
+                               extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
+                               build_directory=cache_dir,
+                               extra_include_paths=[get_pycuda_include_path(),
+                                                    get_pystencils_include_path()])
         return torch_extension
-- 
GitLab