Skip to content
Snippets Groups Projects
Commit 74deb68c authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add optional `framework_module_class` parameter for codegeneration

parent e4224cce
Branches reproducible-hash-for-textures
Tags
No related merge requests found
......@@ -154,7 +154,7 @@ except Exception as e:
warnings.warn(str(e))
def generate_shared_object(output_folder=None, source_files=None, show_code=False):
def generate_shared_object(output_folder=None, source_files=None, show_code=False, framework_module_class=TorchModule):
object_cache = get_cache_config()['object_cache']
......@@ -177,7 +177,7 @@ def generate_shared_object(output_folder=None, source_files=None, show_code=Fals
copyfile(s, dst) # Torch only accepts *.cu as CUDA
cuda_sources.append(dst)
module = TorchModule(module_name, FUNCTIONS.values())
module = framework_module_class(module_name, FUNCTIONS.values())
if show_code:
pystencils.show_code(module, custom_backend=FrameworkIntegrationPrinter())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment