diff --git a/src/pystencils_autodiff/tensorflow_jit.py b/src/pystencils_autodiff/tensorflow_jit.py index 6b8fcaeeb7119b77ad2b1e79b0c4323e206ddb43..05a12695622ce67142a10ab7efaf7fa2ec8a9dc7 100644 --- a/src/pystencils_autodiff/tensorflow_jit.py +++ b/src/pystencils_autodiff/tensorflow_jit.py @@ -161,12 +161,16 @@ def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_f file_name = join(pystencils.cache.cache_dir, f'{_hash(source_code.encode()).hexdigest()}{file_extension}') write_file(file_name, source_code) - compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False, + compile_file(file_name, + use_nvcc=is_cuda, + overwrite_destination_file=False, additional_compile_flags=additional_compile_flags) object_files.append(file_name + '.o') print('Linking Tensorflow module...') - module = link_and_load(object_files, overwrite_destination_file=False, additional_link_flags=additional_link_flags) + module = link_and_load(object_files, + overwrite_destination_file=False, + additional_link_flags=additional_link_flags) if module: print('Loaded Tensorflow module.') return module