From 94381fb928e0607ead14804d0fa81e7d2b00292f Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 24 Feb 2020 17:05:47 +0100
Subject: [PATCH] Add pyronn_torch.codegen, load shared object in __init__

---
 src/pyronn_torch/__init__.py                  | 10 ++++++
 .../pyronn_torch/codegen.py                   | 36 ++++++++++++-------
 2 files changed, 34 insertions(+), 12 deletions(-)
 rename codegen/generate_wrappers.py => src/pyronn_torch/codegen.py (93%)

diff --git a/src/pyronn_torch/__init__.py b/src/pyronn_torch/__init__.py
index 7452884..e76fcf2 100644
--- a/src/pyronn_torch/__init__.py
+++ b/src/pyronn_torch/__init__.py
@@ -1,5 +1,8 @@
 # -*- coding: utf-8 -*-
 
+from os.path import dirname, join
+
+import torch
 from pkg_resources import DistributionNotFound, get_distribution
 
 try:
@@ -11,3 +14,10 @@ except DistributionNotFound:
 finally:
     del get_distribution, DistributionNotFound
 
+
+try:
+    cpp_extension = torch.ops.load_library(join(dirname(__file__), 'pyronn_torch.so'))
+except Exception:
+    import pyronn_torch.codegen
+    cpp_extension = pyronn_torch.codegen.compile_shared_object()
+
diff --git a/codegen/generate_wrappers.py b/src/pyronn_torch/codegen.py
similarity index 93%
rename from codegen/generate_wrappers.py
rename to src/pyronn_torch/codegen.py
index 9d28436..4e17c76 100644
--- a/codegen/generate_wrappers.py
+++ b/src/pyronn_torch/codegen.py
@@ -146,24 +146,25 @@ void Parallel_Backprojection2D_Kernel_Launcher(const float *sinogram_ptr, float
         }
 
 
-def main():
+def generate_shared_object(output_folder=None, source_files=None, show_code=False):
 
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--output-folder',  default=join(dirname(__file__), '..', 'src', 'pyronn_torch'))
-    parser.add_argument('--source-files',  default=glob(join(dirname(__file__),
-                                                             '..', 'src', 'pyronn_torch', 'PYRO-NN-Layers', '*.cu.cc')))
-    args = parser.parse_args()
     object_cache = get_cache_config()['object_cache']
 
     module_name = 'PYRO_NN'
 
+    if not output_folder:
+        output_folder = dirname(__file__)
+
+    if not source_files:
+        source_files = glob(join(dirname(__file__), 'PYRO-NN-Layers', '*.cu.cc'))
+
     cuda_sources = []
     makedirs(join(object_cache, module_name), exist_ok=True)
-    rmtree(join(object_cache, module_name, 'helper_headers'))
-    copytree(join(dirname(__file__), '..', 'src', 'pyronn_torch',
-                  'PYRO-NN-Layers', 'helper_headers'), join(object_cache, module_name, 'helper_headers'))
+    rmtree(join(object_cache, module_name, 'helper_headers'), ignore_errors=True)
+    copytree(join(dirname(__file__), 'PYRO-NN-Layers', 'helper_headers'),
+             join(object_cache, module_name, 'helper_headers'))
 
-    for s in args.source_files:
+    for s in source_files:
         dst = join(object_cache, module_name, basename(s).replace('.cu.cc', '.cu'))
         copyfile(s, dst)  # Torch only accepts *.cu as CUDA
         cuda_sources.append(dst)
@@ -171,7 +172,8 @@ def main():
     functions = [WrapperFunction(Block([v]), function_name=k) for k, v in FUNCTIONS.items()]
     module = TorchModule(module_name, functions, wrap_wrapper_functions=True)
 
-    pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter())
+    if show_code:
+        pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter())
 
     extension = module.compile(extra_source_files=cuda_sources, extra_cuda_flags=['-arch=sm_35'], with_cuda=True)
 
@@ -180,7 +182,17 @@ def main():
             print(v.__doc__)
 
     shared_object_file = module.compiled_file.replace('.cpp', '.so')
-    copyfile(shared_object_file, join(args.output_folder, 'pyronn_torch.so'))
+    copyfile(shared_object_file, join(output_folder, 'pyronn_torch.so'))
+
+
+def main():
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--output-folder',  default=None)
+    parser.add_argument('--source-files',  default=None)
+    args = parser.parse_args()
+
+    generate_shared_object(args.output_folder, args.source_files, show_code=True)
 
 
 if __name__ == '__main__':
-- 
GitLab