diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 286b39cf75bef980b745eec3a98734f9c70c9e23..95627c5b0550bc506fba5305ca09805badac55de 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -119,6 +119,9 @@ class TorchModule(JinjaCppFile):
         if not exists(file_name):
             write_file(file_name, source_code)
 
+        # Torch regards CXX
+        os.environ['CXX'] = get_compiler_config()['command']
+
         torch_extension = load(hash,
                                [file_name],
                                with_cuda=self.is_cuda,