Skip to content
Snippets Groups Projects
Commit 32e36a69 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Allow injecting class definitions into Python bindings

parent 32bb1d31
Branches
Tags
No related merge requests found
Pipeline #27122 failed with stage
in 8 minutes and 20 seconds
...@@ -102,7 +102,12 @@ class TorchModule(JinjaCppFile): ...@@ -102,7 +102,12 @@ class TorchModule(JinjaCppFile):
def backend(self): def backend(self):
return 'gpucuda' if self.is_cuda else 'c' return 'gpucuda' if self.is_cuda else 'c'
def __init__(self, module_name, kernel_asts, with_python_bindings=True, wrap_wrapper_functions=False): def __init__(self,
module_name,
kernel_asts,
with_python_bindings=True,
wrap_wrapper_functions=False,
class_definitions=[]):
"""Create a C++ module with forward and optional backward_kernels """Create a C++ module with forward and optional backward_kernels
:param forward_kernel_ast: one or more kernel ASTs (can have any C dialect) :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
...@@ -125,7 +130,7 @@ class TorchModule(JinjaCppFile): ...@@ -125,7 +130,7 @@ class TorchModule(JinjaCppFile):
'module_name': module_name, 'module_name': module_name,
'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name, 'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name,
[self.PYTHON_FUNCTION_WRAPPING_CLASS(a) [self.PYTHON_FUNCTION_WRAPPING_CLASS(a)
for a in wrapper_functions]) for a in wrapper_functions] + class_definitions)
if with_python_bindings else '' if with_python_bindings else ''
} }
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment