diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 4c6c41d0f0eca8f5a4cff1908864e1dce29355b2..a84f7bfe93ac8ff4ae8dbb725bf25578c2871a95 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -14,7 +14,7 @@ from collections.abc import Iterable from os.path import dirname, exists, join from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol -from pystencils.cpu.cpujit import get_cache_config +from pystencils.cpu.cpujit import get_cache_config, get_compiler_config from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils_autodiff._file_io import read_template_from_file, write_file from pystencils_autodiff.backends.python_bindings import ( @@ -123,7 +123,7 @@ class TorchModule(JinjaCppFile): [file_name], with_cuda=self.is_cuda, extra_cflags=['--std=c++14'], - extra_cuda_cflags=['-std=c++14'], + extra_cuda_cflags=['-std=c++14', '--ccbin', get_compiler_config()['command']], build_directory=build_dir, extra_include_paths=[get_pycuda_include_path(), get_pystencils_include_path()])