From 94381fb928e0607ead14804d0fa81e7d2b00292f Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 24 Feb 2020 17:05:47 +0100 Subject: [PATCH] Add pyronn_torch.codegen, load shared object in __init__ --- src/pyronn_torch/__init__.py | 10 ++++++ .../pyronn_torch/codegen.py | 36 ++++++++++++------- 2 files changed, 34 insertions(+), 12 deletions(-) rename codegen/generate_wrappers.py => src/pyronn_torch/codegen.py (93%) diff --git a/src/pyronn_torch/__init__.py b/src/pyronn_torch/__init__.py index 7452884..e76fcf2 100644 --- a/src/pyronn_torch/__init__.py +++ b/src/pyronn_torch/__init__.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- +from os.path import dirname, join + +import torch from pkg_resources import DistributionNotFound, get_distribution try: @@ -11,3 +14,10 @@ except DistributionNotFound: finally: del get_distribution, DistributionNotFound + +try: + cpp_extension = torch.ops.load_library(join(dirname(__file__), 'pyronn_torch.so')) +except Exception: + import pyronn_torch.codegen + cpp_extension = pyronn_torch.codegen.compile_shared_object() + diff --git a/codegen/generate_wrappers.py b/src/pyronn_torch/codegen.py similarity index 93% rename from codegen/generate_wrappers.py rename to src/pyronn_torch/codegen.py index 9d28436..4e17c76 100644 --- a/codegen/generate_wrappers.py +++ b/src/pyronn_torch/codegen.py @@ -146,24 +146,25 @@ void Parallel_Backprojection2D_Kernel_Launcher(const float *sinogram_ptr, float } -def main(): +def generate_shared_object(output_folder=None, source_files=None, show_code=False): - parser = argparse.ArgumentParser() - parser.add_argument('--output-folder', default=join(dirname(__file__), '..', 'src', 'pyronn_torch')) - parser.add_argument('--source-files', default=glob(join(dirname(__file__), - '..', 'src', 'pyronn_torch', 'PYRO-NN-Layers', '*.cu.cc'))) - args = parser.parse_args() object_cache = get_cache_config()['object_cache'] module_name = 'PYRO_NN' + if not output_folder: + output_folder = dirname(__file__) + + if not source_files: + source_files = glob(join(dirname(__file__), 'PYRO-NN-Layers', '*.cu.cc')) + cuda_sources = [] makedirs(join(object_cache, module_name), exist_ok=True) - rmtree(join(object_cache, module_name, 'helper_headers')) - copytree(join(dirname(__file__), '..', 'src', 'pyronn_torch', - 'PYRO-NN-Layers', 'helper_headers'), join(object_cache, module_name, 'helper_headers')) + rmtree(join(object_cache, module_name, 'helper_headers'), ignore_errors=True) + copytree(join(dirname(__file__), 'PYRO-NN-Layers', 'helper_headers'), + join(object_cache, module_name, 'helper_headers')) - for s in args.source_files: + for s in source_files: dst = join(object_cache, module_name, basename(s).replace('.cu.cc', '.cu')) copyfile(s, dst) # Torch only accepts *.cu as CUDA cuda_sources.append(dst) @@ -171,7 +172,8 @@ def main(): functions = [WrapperFunction(Block([v]), function_name=k) for k, v in FUNCTIONS.items()] module = TorchModule(module_name, functions, wrap_wrapper_functions=True) - pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter()) + if show_code: + pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter()) extension = module.compile(extra_source_files=cuda_sources, extra_cuda_flags=['-arch=sm_35'], with_cuda=True) @@ -180,7 +182,17 @@ def main(): print(v.__doc__) shared_object_file = module.compiled_file.replace('.cpp', '.so') - copyfile(shared_object_file, join(args.output_folder, 'pyronn_torch.so')) + copyfile(shared_object_file, join(output_folder, 'pyronn_torch.so')) + + +def main(): + + parser = argparse.ArgumentParser() + parser.add_argument('--output-folder', default=None) + parser.add_argument('--source-files', default=None) + args = parser.parse_args() + + generate_shared_object(args.output_folder, args.source_files, show_code=True) if __name__ == '__main__': -- GitLab