diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index e51eb1b3e6a58293557c2628a82c2021ab792031..62eb5e9ecdc1765a4258e7329416a9d3df4edefe 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -125,7 +125,7 @@ class TorchModule(JinjaCppFile):
         torch_extension = load(hash,
                                [file_name],
                                with_cuda=self.is_cuda,
-                               extra_cflags=['--std=c++14'],
+                               extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
                                extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
                                build_directory=build_dir,
                                extra_include_paths=[get_pycuda_include_path(),
@@ -198,9 +198,19 @@ setup_pybind11(cfg)
         if cache_dir not in sys.path:
             sys.path.append(cache_dir)
 
+        # Torch regards CXX
+        os.environ['CXX'] = get_compiler_config()['command']
+
         try:
             torch_extension = cppimport.imp(f'cppimport_{hash_str}')
         except Exception as e:
             print(e)
-            torch_extension = load(self.module_name, [file_name])
+            torch_extension = load(hash,
+                               [file_name],
+                               with_cuda=self.is_cuda,
+                               extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
+                               extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
+                               build_directory=cache_dir,
+                               extra_include_paths=[get_pycuda_include_path(),
+                                                    get_pystencils_include_path()])
         return torch_extension