diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 286b39cf75bef980b745eec3a98734f9c70c9e23..95627c5b0550bc506fba5305ca09805badac55de 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -119,6 +119,9 @@ class TorchModule(JinjaCppFile): if not exists(file_name): write_file(file_name, source_code) + # Torch regards CXX + os.environ['CXX'] = get_compiler_config()['command'] + torch_extension = load(hash, [file_name], with_cuda=self.is_cuda,