diff --git a/src/pyronn_torch/codegen.py b/src/pyronn_torch/codegen.py index c3bc4134d1b1c974d89723918c4dd52e420e52ab..314975fa8a36f8254d05b9c856c1eaae64fd8787 100644 --- a/src/pyronn_torch/codegen.py +++ b/src/pyronn_torch/codegen.py @@ -154,7 +154,7 @@ except Exception as e: warnings.warn(str(e)) -def generate_shared_object(output_folder=None, source_files=None, show_code=False): +def generate_shared_object(output_folder=None, source_files=None, show_code=False, framework_module_class=TorchModule): object_cache = get_cache_config()['object_cache'] @@ -177,7 +177,7 @@ def generate_shared_object(output_folder=None, source_files=None, show_code=Fals copyfile(s, dst) # Torch only accepts *.cu as CUDA cuda_sources.append(dst) - module = TorchModule(module_name, FUNCTIONS.values()) + module = framework_module_class(module_name, FUNCTIONS.values()) if show_code: pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter())