diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 4ca0952c804a7bc8b6039c451f65170e1f94228f..9b11c8d929aa68d35bf6d5fbe8384f7848958510 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -102,7 +102,12 @@ class TorchModule(JinjaCppFile): def backend(self): return 'gpucuda' if self.is_cuda else 'c' - def __init__(self, module_name, kernel_asts, with_python_bindings=True, wrap_wrapper_functions=False): + def __init__(self, + module_name, + kernel_asts, + with_python_bindings=True, + wrap_wrapper_functions=False, + class_definitions=[]): """Create a C++ module with forward and optional backward_kernels :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect) @@ -125,7 +130,7 @@ class TorchModule(JinjaCppFile): 'module_name': module_name, 'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name, [self.PYTHON_FUNCTION_WRAPPING_CLASS(a) - for a in wrapper_functions]) + for a in wrapper_functions] + class_definitions) if with_python_bindings else '' }