diff --git a/src/pyronn_torch/codegen.py b/src/pyronn_torch/codegen.py
index c3bc4134d1b1c974d89723918c4dd52e420e52ab..314975fa8a36f8254d05b9c856c1eaae64fd8787 100644
--- a/src/pyronn_torch/codegen.py
+++ b/src/pyronn_torch/codegen.py
@@ -154,7 +154,7 @@ except Exception as e:
     warnings.warn(str(e))
 
 
-def generate_shared_object(output_folder=None, source_files=None, show_code=False):
+def generate_shared_object(output_folder=None, source_files=None, show_code=False, framework_module_class=TorchModule):
 
     object_cache = get_cache_config()['object_cache']
 
@@ -177,7 +177,7 @@ def generate_shared_object(output_folder=None, source_files=None, show_code=Fals
         copyfile(s, dst)  # Torch only accepts *.cu as CUDA
         cuda_sources.append(dst)
 
-    module = TorchModule(module_name, FUNCTIONS.values())
+    module = framework_module_class(module_name, FUNCTIONS.values())
 
     if show_code:
         pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter())