diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index 6cd95960945d0000f297a8d8d75ef521f6e3e745..1384893b0f9c2c06b5f37142c0f36a02885dccf2 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -8,6 +8,7 @@ """ import hashlib +import os import subprocess import sysconfig from os.path import exists, join @@ -45,7 +46,7 @@ except ImportError: pass -def link_and_load(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]): +def link(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]): """Compiles given :param:`source_file` to a Tensorflow shared Library. .. warning:: @@ -75,11 +76,18 @@ def link_and_load(object_files, destination_file=None, overwrite_destination_fil subprocess.check_call(command) + return destination_file + + +def link_and_load(object_files, destination_file=None, overwrite_destination_file=True, additional_link_flags=[]): + destination_file = link(object_files, destination_file, overwrite_destination_file, additional_link_flags) lib = tf.load_op_library(destination_file) return lib def try_get_cuda_arch_flag(): + if 'PYSTENCILS_TENSORFLOW_NVCC_ARCH' in os.environ: + return "-arch=sm_" + os.environ['PYSTENCILS_TENSORFLOW_NVCC_ARCH'] try: from pycuda.driver import Context arch = "sm_%d%d" % Context.get_device().compute_capability() @@ -143,7 +151,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T return destination_file -def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_flags=[], additional_link_flags=[]): +def compile_sources(host_sources, cuda_sources=[], additional_compile_flags=[], additional_link_flags=[]): object_files = [] @@ -172,3 +180,42 @@ def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_f if module: print('Loaded Tensorflow module.') return module + + +def compile_sources_and_load(host_sources, + cuda_sources=[], + additional_compile_flags=[], + additional_link_flags=[], + compile_only=False): + + object_files = [] + + for source in tqdm(host_sources + cuda_sources, desc='Compiling Tensorflow module'): + is_cuda = source in cuda_sources + + if exists(source): + source_code = read_file(source) + else: + source_code = source + + file_extension = '.cu' if is_cuda else '.cpp' + file_name = join(pystencils.cache.cache_dir, f'{_hash(source_code.encode()).hexdigest()}{file_extension}') + write_file(file_name, source_code) + + compile_file(file_name, + use_nvcc=is_cuda, + overwrite_destination_file=False, + additional_compile_flags=additional_compile_flags) + object_files.append(file_name + '.o') + + print('Linking Tensorflow module...') + module_file = link(object_files, + overwrite_destination_file=False, + additional_link_flags=additional_link_flags) + if not compile_only: + module = tf.load_op_library(module_file) + if module: + print('Loaded Tensorflow module.') + return module + else: + return module_file diff --git a/tests/test_tensorflow_jit.py b/tests/test_tensorflow_jit.py index 187aa96a925aa113ed7cec70286aa993836036ea..bea327e7c1cb0b4caefdb55bfe4951977200e308 100644 --- a/tests/test_tensorflow_jit.py +++ b/tests/test_tensorflow_jit.py @@ -8,6 +8,8 @@ """ +from os.path import exists + import pytest import sympy @@ -44,6 +46,10 @@ def test_tensorflow_jit_gpu(): assert 'call_forward_jit_gpu' in dir(lib) assert 'call_backward_jit_gpu' in dir(lib) + file_name = pystencils_autodiff.tensorflow_jit.compile_sources_and_load([], [str(module)], compile_only=True) + print(file_name) + assert exists(file_name) + def test_tensorflow_jit_cpu():