Skip to content
Snippets Groups Projects
Commit 32e36a69 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Allow injecting class definitions into Python bindings

parent 32bb1d31
Branches
Tags
No related merge requests found
Pipeline #27122 failed
......@@ -102,7 +102,12 @@ class TorchModule(JinjaCppFile):
def backend(self):
return 'gpucuda' if self.is_cuda else 'c'
def __init__(self, module_name, kernel_asts, with_python_bindings=True, wrap_wrapper_functions=False):
def __init__(self,
module_name,
kernel_asts,
with_python_bindings=True,
wrap_wrapper_functions=False,
class_definitions=[]):
"""Create a C++ module with forward and optional backward_kernels
:param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
......@@ -125,7 +130,7 @@ class TorchModule(JinjaCppFile):
'module_name': module_name,
'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name,
[self.PYTHON_FUNCTION_WRAPPING_CLASS(a)
for a in wrapper_functions])
for a in wrapper_functions] + class_definitions)
if with_python_bindings else ''
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment