diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index e51eb1b3e6a58293557c2628a82c2021ab792031..62eb5e9ecdc1765a4258e7329416a9d3df4edefe 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -125,7 +125,7 @@ class TorchModule(JinjaCppFile): torch_extension = load(hash, [file_name], with_cuda=self.is_cuda, - extra_cflags=['--std=c++14'], + extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')], extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']], build_directory=build_dir, extra_include_paths=[get_pycuda_include_path(), @@ -198,9 +198,19 @@ setup_pybind11(cfg) if cache_dir not in sys.path: sys.path.append(cache_dir) + # Torch regards CXX + os.environ['CXX'] = get_compiler_config()['command'] + try: torch_extension = cppimport.imp(f'cppimport_{hash_str}') except Exception as e: print(e) - torch_extension = load(self.module_name, [file_name]) + torch_extension = load(hash, + [file_name], + with_cuda=self.is_cuda, + extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')], + extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']], + build_directory=cache_dir, + extra_include_paths=[get_pycuda_include_path(), + get_pystencils_include_path()]) return torch_extension