diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 4ca0952c804a7bc8b6039c451f65170e1f94228f..9b11c8d929aa68d35bf6d5fbe8384f7848958510 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -102,7 +102,12 @@ class TorchModule(JinjaCppFile):
     def backend(self):
         return 'gpucuda' if self.is_cuda else 'c'
 
-    def __init__(self, module_name, kernel_asts, with_python_bindings=True, wrap_wrapper_functions=False):
+    def __init__(self,
+                 module_name,
+                 kernel_asts,
+                 with_python_bindings=True,
+                 wrap_wrapper_functions=False,
+                 class_definitions=[]):
         """Create a C++ module with forward and optional backward_kernels
 
         :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
@@ -125,7 +130,7 @@ class TorchModule(JinjaCppFile):
             'module_name': module_name,
             'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name,
                                                           [self.PYTHON_FUNCTION_WRAPPING_CLASS(a)
-                                                              for a in wrapper_functions])
+                                                              for a in wrapper_functions] + class_definitions)
             if with_python_bindings else ''
         }