Skip to content
Snippets Groups Projects
Commit 74d93c0e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Use pystencils compiler flags for Pybind11/Torch

parent 3b6ec0fd
No related branches found
No related tags found
No related merge requests found
...@@ -125,7 +125,7 @@ class TorchModule(JinjaCppFile): ...@@ -125,7 +125,7 @@ class TorchModule(JinjaCppFile):
torch_extension = load(hash, torch_extension = load(hash,
[file_name], [file_name],
with_cuda=self.is_cuda, with_cuda=self.is_cuda,
extra_cflags=['--std=c++14'], extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']], extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
build_directory=build_dir, build_directory=build_dir,
extra_include_paths=[get_pycuda_include_path(), extra_include_paths=[get_pycuda_include_path(),
...@@ -198,9 +198,19 @@ setup_pybind11(cfg) ...@@ -198,9 +198,19 @@ setup_pybind11(cfg)
if cache_dir not in sys.path: if cache_dir not in sys.path:
sys.path.append(cache_dir) sys.path.append(cache_dir)
# Torch regards CXX
os.environ['CXX'] = get_compiler_config()['command']
try: try:
torch_extension = cppimport.imp(f'cppimport_{hash_str}') torch_extension = cppimport.imp(f'cppimport_{hash_str}')
except Exception as e: except Exception as e:
print(e) print(e)
torch_extension = load(self.module_name, [file_name]) torch_extension = load(hash,
[file_name],
with_cuda=self.is_cuda,
extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
build_directory=cache_dir,
extra_include_paths=[get_pycuda_include_path(),
get_pystencils_include_path()])
return torch_extension return torch_extension
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment