diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index ab1a05eb7177757d35412335f861819b1451126b..99bd6d751707a631cf0e6538a45ffce7b17163cc 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -13,9 +13,9 @@ import subprocess import sysconfig from os.path import exists, join +import p_tqdm import pystencils from pystencils.cpu.cpujit import get_cache_config, get_compiler_config, get_pystencils_include_path -from tqdm import tqdm from pystencils_autodiff._file_io import read_file, write_file @@ -156,37 +156,6 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T return destination_file -def compile_sources(host_sources, cuda_sources=[], additional_compile_flags=[], additional_link_flags=[]): - - object_files = [] - - for source in tqdm(host_sources + cuda_sources, desc='Compiling Tensorflow module'): - is_cuda = source in cuda_sources - - if exists(source): - source_code = read_file(source) - else: - source_code = source - - file_extension = '.cu' if is_cuda else '.cpp' - file_name = join(pystencils.cache.cache_dir, f'{_hash(source_code.encode()).hexdigest()}{file_extension}') - write_file(file_name, source_code) - - compile_file(file_name, - use_nvcc=is_cuda, - overwrite_destination_file=False, - additional_compile_flags=additional_compile_flags) - object_files.append(file_name + _object_file_extension) - - print('Linking Tensorflow module...') - module = link_and_load(object_files, - overwrite_destination_file=False, - additional_link_flags=additional_link_flags) - if module: - print('Loaded Tensorflow module.') - return module - - def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_flags=[], @@ -194,8 +163,11 @@ def compile_sources_and_load(host_sources, compile_only=False): object_files = [] + sources = host_sources + cuda_sources - for source in tqdm(host_sources + cuda_sources, desc='Compiling Tensorflow module'): + print('Compiling Tensorflow module...') + + def compile(source): is_cuda = source in cuda_sources if exists(source): @@ -211,7 +183,10 @@ def compile_sources_and_load(host_sources, use_nvcc=is_cuda, overwrite_destination_file=False, additional_compile_flags=additional_compile_flags) - object_files.append(object_file) + return object_file + + # p_tqdm is just a parallel tqdm + object_files = p_tqdm.p_umap(compile, sources) print('Linking Tensorflow module...') module_file = link(object_files,