Skip to content
Snippets Groups Projects
Commit 74d93c0e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Use pystencils compiler flags for Pybind11/Torch

parent 3b6ec0fd
Branches
Tags
No related merge requests found
......@@ -125,7 +125,7 @@ class TorchModule(JinjaCppFile):
torch_extension = load(hash,
[file_name],
with_cuda=self.is_cuda,
extra_cflags=['--std=c++14'],
extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
build_directory=build_dir,
extra_include_paths=[get_pycuda_include_path(),
......@@ -198,9 +198,19 @@ setup_pybind11(cfg)
if cache_dir not in sys.path:
sys.path.append(cache_dir)
# Torch regards CXX
os.environ['CXX'] = get_compiler_config()['command']
try:
torch_extension = cppimport.imp(f'cppimport_{hash_str}')
except Exception as e:
print(e)
torch_extension = load(self.module_name, [file_name])
torch_extension = load(hash,
[file_name],
with_cuda=self.is_cuda,
extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
build_directory=cache_dir,
extra_include_paths=[get_pycuda_include_path(),
get_pystencils_include_path()])
return torch_extension
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment