From 74deb68cef30472f57e30a97b9266e7020c36344 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Thu, 27 Feb 2020 10:31:38 +0100 Subject: [PATCH] Add optional `framework_module_class` parameter for codegeneration --- src/pyronn_torch/codegen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyronn_torch/codegen.py b/src/pyronn_torch/codegen.py index c3bc413..314975f 100644 --- a/src/pyronn_torch/codegen.py +++ b/src/pyronn_torch/codegen.py @@ -154,7 +154,7 @@ except Exception as e: warnings.warn(str(e)) -def generate_shared_object(output_folder=None, source_files=None, show_code=False): +def generate_shared_object(output_folder=None, source_files=None, show_code=False, framework_module_class=TorchModule): object_cache = get_cache_config()['object_cache'] @@ -177,7 +177,7 @@ def generate_shared_object(output_folder=None, source_files=None, show_code=Fals copyfile(s, dst) # Torch only accepts *.cu as CUDA cuda_sources.append(dst) - module = TorchModule(module_name, FUNCTIONS.values()) + module = framework_module_class(module_name, FUNCTIONS.values()) if show_code: pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter()) -- GitLab