Skip to content
Snippets Groups Projects
Commit 94381fb9 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add pyronn_torch.codegen, load shared object in __init__

parent 7a2153e7
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
from os.path import dirname, join
import torch
from pkg_resources import DistributionNotFound, get_distribution
try:
......@@ -11,3 +14,10 @@ except DistributionNotFound:
finally:
del get_distribution, DistributionNotFound
try:
cpp_extension = torch.ops.load_library(join(dirname(__file__), 'pyronn_torch.so'))
except Exception:
import pyronn_torch.codegen
cpp_extension = pyronn_torch.codegen.compile_shared_object()
......@@ -146,24 +146,25 @@ void Parallel_Backprojection2D_Kernel_Launcher(const float *sinogram_ptr, float
}
def main():
def generate_shared_object(output_folder=None, source_files=None, show_code=False):
parser = argparse.ArgumentParser()
parser.add_argument('--output-folder', default=join(dirname(__file__), '..', 'src', 'pyronn_torch'))
parser.add_argument('--source-files', default=glob(join(dirname(__file__),
'..', 'src', 'pyronn_torch', 'PYRO-NN-Layers', '*.cu.cc')))
args = parser.parse_args()
object_cache = get_cache_config()['object_cache']
module_name = 'PYRO_NN'
if not output_folder:
output_folder = dirname(__file__)
if not source_files:
source_files = glob(join(dirname(__file__), 'PYRO-NN-Layers', '*.cu.cc'))
cuda_sources = []
makedirs(join(object_cache, module_name), exist_ok=True)
rmtree(join(object_cache, module_name, 'helper_headers'))
copytree(join(dirname(__file__), '..', 'src', 'pyronn_torch',
'PYRO-NN-Layers', 'helper_headers'), join(object_cache, module_name, 'helper_headers'))
rmtree(join(object_cache, module_name, 'helper_headers'), ignore_errors=True)
copytree(join(dirname(__file__), 'PYRO-NN-Layers', 'helper_headers'),
join(object_cache, module_name, 'helper_headers'))
for s in args.source_files:
for s in source_files:
dst = join(object_cache, module_name, basename(s).replace('.cu.cc', '.cu'))
copyfile(s, dst) # Torch only accepts *.cu as CUDA
cuda_sources.append(dst)
......@@ -171,7 +172,8 @@ def main():
functions = [WrapperFunction(Block([v]), function_name=k) for k, v in FUNCTIONS.items()]
module = TorchModule(module_name, functions, wrap_wrapper_functions=True)
pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter())
if show_code:
pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter())
extension = module.compile(extra_source_files=cuda_sources, extra_cuda_flags=['-arch=sm_35'], with_cuda=True)
......@@ -180,7 +182,17 @@ def main():
print(v.__doc__)
shared_object_file = module.compiled_file.replace('.cpp', '.so')
copyfile(shared_object_file, join(args.output_folder, 'pyronn_torch.so'))
copyfile(shared_object_file, join(output_folder, 'pyronn_torch.so'))
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--output-folder', default=None)
parser.add_argument('--source-files', default=None)
args = parser.parse_args()
generate_shared_object(args.output_folder, args.source_files, show_code=True)
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment