Skip to content
Snippets Groups Projects
Commit f6a7e705 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Make hashing work for tensorflow_jit

parent 3bcda22b
Branches
Tags
No related merge requests found
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
""" """
""" """
import hashlib
import json import json
import subprocess import subprocess
import sysconfig import sysconfig
...@@ -18,6 +19,9 @@ import pystencils ...@@ -18,6 +19,9 @@ import pystencils
from pystencils.cpu.cpujit import get_cache_config, get_compiler_config, get_pystencils_include_path from pystencils.cpu.cpujit import get_cache_config, get_compiler_config, get_pystencils_include_path
from pystencils_autodiff._file_io import read_file, write_file from pystencils_autodiff._file_io import read_file, write_file
_hash = hashlib.md5
# TODO: msvc # TODO: msvc
if get_compiler_config()['os'] != 'windows': if get_compiler_config()['os'] != 'windows':
_shared_object_flag = '-shared' _shared_object_flag = '-shared'
...@@ -56,7 +60,8 @@ def link_and_load(object_files, destination_file=None, overwrite_destination_fil ...@@ -56,7 +60,8 @@ def link_and_load(object_files, destination_file=None, overwrite_destination_fil
""" """
if not destination_file: if not destination_file:
destination_file = join(get_cache_config()['object_cache'], f"{abs(hash(tuple(object_files))):x}.so") destination_file = join(get_cache_config()['object_cache'],
f"{_hash('.'.join(sorted(object_files)).encode()).hexdigest()}.so")
if not exists(destination_file) or overwrite_destination_file: if not exists(destination_file) or overwrite_destination_file:
command = [get_compiler_config()['command'], command = [get_compiler_config()['command'],
...@@ -136,6 +141,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T ...@@ -136,6 +141,7 @@ def compile_file(file, use_nvcc=False, nvcc='nvcc', overwrite_destination_file=T
if not exists(destination_file) or overwrite_destination_file: if not exists(destination_file) or overwrite_destination_file:
subprocess.check_call(command) subprocess.check_call(command)
return destination_file return destination_file
...@@ -152,7 +158,7 @@ def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_f ...@@ -152,7 +158,7 @@ def compile_sources_and_load(host_sources, cuda_sources=[], additional_compile_f
source_code = source source_code = source
file_extension = '.cu' if is_cuda else '.cpp' file_extension = '.cu' if is_cuda else '.cpp'
file_name = join(pystencils.cache.cache_dir, f'{abs(hash(source_code)):x}{file_extension}') file_name = join(pystencils.cache.cache_dir, f'{_hash(source_code.encode()).hexdigest()}{file_extension}')
write_file(file_name, source_code) write_file(file_name, source_code)
compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False, compile_file(file_name, use_nvcc=is_cuda, overwrite_destination_file=False,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment