From 74deb68cef30472f57e30a97b9266e7020c36344 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 27 Feb 2020 10:31:38 +0100
Subject: [PATCH] Add optional `framework_module_class` parameter for
 codegeneration

---
 src/pyronn_torch/codegen.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/pyronn_torch/codegen.py b/src/pyronn_torch/codegen.py
index c3bc413..314975f 100644
--- a/src/pyronn_torch/codegen.py
+++ b/src/pyronn_torch/codegen.py
@@ -154,7 +154,7 @@ except Exception as e:
     warnings.warn(str(e))
 
 
-def generate_shared_object(output_folder=None, source_files=None, show_code=False):
+def generate_shared_object(output_folder=None, source_files=None, show_code=False, framework_module_class=TorchModule):
 
     object_cache = get_cache_config()['object_cache']
 
@@ -177,7 +177,7 @@ def generate_shared_object(output_folder=None, source_files=None, show_code=Fals
         copyfile(s, dst)  # Torch only accepts *.cu as CUDA
         cuda_sources.append(dst)
 
-    module = TorchModule(module_name, FUNCTIONS.values())
+    module = framework_module_class(module_name, FUNCTIONS.values())
 
     if show_code:
         pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter())
-- 
GitLab