From 32e36a69c59da8cab7cb1bad28dfdd4cb2acf2da Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Wed, 7 Oct 2020 13:20:08 +0200 Subject: [PATCH] Allow injecting class definitions into Python bindings --- src/pystencils_autodiff/backends/astnodes.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 4ca0952..9b11c8d 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -102,7 +102,12 @@ class TorchModule(JinjaCppFile): def backend(self): return 'gpucuda' if self.is_cuda else 'c' - def __init__(self, module_name, kernel_asts, with_python_bindings=True, wrap_wrapper_functions=False): + def __init__(self, + module_name, + kernel_asts, + with_python_bindings=True, + wrap_wrapper_functions=False, + class_definitions=[]): """Create a C++ module with forward and optional backward_kernels :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect) @@ -125,7 +130,7 @@ class TorchModule(JinjaCppFile): 'module_name': module_name, 'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name, [self.PYTHON_FUNCTION_WRAPPING_CLASS(a) - for a in wrapper_functions]) + for a in wrapper_functions] + class_definitions) if with_python_bindings else '' } -- GitLab