Skip to content
Snippets Groups Projects
Commit 8319600b authored by Stephan Seitz
Browse files

Enable parallel compilation in tensorflow_jit

parent 798d2f80
Branches
Tags
No related merge requests found
Pipeline #18030 failed
...@@ -13,9 +13,9 @@ import subprocess ...@@ -13,9 +13,9 @@ import subprocess
import sysconfig import sysconfig
from os.path import exists, join from os.path import exists, join
import p_tqdm
import pystencils import pystencils
from pystencils.cpu.cpujit import get_cache_config, get_compiler_config, get_pystencils_include_path from pystencils.cpu.cpujit import get_cache_config, get_compiler_config, get_pystencils_include_path
from tqdm import tqdm
from pystencils_autodiff._file_io import read_file, write_file from pystencils_autodiff._file_io import read_file, write_file
...@@ -156,37 +156,6 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T ...@@ -156,37 +156,6 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
return destination_file return destination_file
def compile_sources(host_sources, cuda_sources=[], additional_compile_flags=[], additional_link_flags=[]):
    """Compile host and CUDA sources one after another, then link and load them.

    Every entry of ``host_sources`` and ``cuda_sources`` may either be a path
    to an existing file (whose contents are read) or the source code itself.
    The code is written into ``pystencils.cache.cache_dir`` under a name
    derived from the hash of its contents and handed to ``compile_file``.

    Args:
        host_sources: list of host sources (paths or code strings); stored
            with a ``.cpp`` extension.
        cuda_sources: list of CUDA sources (paths or code strings); stored
            with a ``.cu`` extension and compiled with ``use_nvcc=True``.
        additional_compile_flags: forwarded to ``compile_file``.
        additional_link_flags: forwarded to ``link_and_load``.

    Returns:
        Whatever ``link_and_load`` returns for the collected object files.
    """

    def _build_object(source):
        # An entry counts as CUDA code purely by membership in ``cuda_sources``.
        is_cuda = source in cuda_sources

        # Entries naming an existing file are read from disk; anything else
        # is taken to be the source text itself.
        code = read_file(source) if exists(source) else source

        if is_cuda:
            suffix = '.cu'
        else:
            suffix = '.cpp'

        # The cache file name is keyed on the content hash of the code.
        digest = _hash(code.encode()).hexdigest()
        cached_source = join(pystencils.cache.cache_dir, f'{digest}{suffix}')
        write_file(cached_source, code)

        compile_file(cached_source,
                     use_nvcc=is_cuda,
                     overwrite_destination_file=False,
                     additional_compile_flags=additional_compile_flags)
        # NOTE(review): presumably compile_file places its output at
        # <source path> + _object_file_extension; its return value is ignored.
        return cached_source + _object_file_extension

    all_sources = host_sources + cuda_sources
    object_files = [_build_object(source)
                    for source in tqdm(all_sources, desc='Compiling Tensorflow module')]

    print('Linking Tensorflow module...')
    module = link_and_load(object_files,
                           overwrite_destination_file=False,
                           additional_link_flags=additional_link_flags)
    if module:
        print('Loaded Tensorflow module.')
    return module
def compile_sources_and_load(host_sources, def compile_sources_and_load(host_sources,
cuda_sources=[], cuda_sources=[],
additional_compile_flags=[], additional_compile_flags=[],
...@@ -194,8 +163,11 @@ def compile_sources_and_load(host_sources, ...@@ -194,8 +163,11 @@ def compile_sources_and_load(host_sources,
compile_only=False): compile_only=False):
object_files = [] object_files = []
sources = host_sources + cuda_sources
for source in tqdm(host_sources + cuda_sources, desc='Compiling Tensorflow module'): print('Compiling Tensorflow module...')
def compile(source):
is_cuda = source in cuda_sources is_cuda = source in cuda_sources
if exists(source): if exists(source):
...@@ -211,7 +183,10 @@ def compile_sources_and_load(host_sources, ...@@ -211,7 +183,10 @@ def compile_sources_and_load(host_sources,
use_nvcc=is_cuda, use_nvcc=is_cuda,
overwrite_destination_file=False, overwrite_destination_file=False,
additional_compile_flags=additional_compile_flags) additional_compile_flags=additional_compile_flags)
object_files.append(object_file) return object_file
# p_tqdm is just a parallel tqdm
object_files = p_tqdm.p_umap(compile, sources)
print('Linking Tensorflow module...') print('Linking Tensorflow module...')
module_file = link(object_files, module_file = link(object_files,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment