From 32e36a69c59da8cab7cb1bad28dfdd4cb2acf2da Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 7 Oct 2020 13:20:08 +0200
Subject: [PATCH] Allow injecting class definitions into Python bindings

---
 src/pystencils_autodiff/backends/astnodes.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 4ca0952..9b11c8d 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -102,7 +102,12 @@ class TorchModule(JinjaCppFile):
     def backend(self):
         return 'gpucuda' if self.is_cuda else 'c'
 
-    def __init__(self, module_name, kernel_asts, with_python_bindings=True, wrap_wrapper_functions=False):
+    def __init__(self,
+                 module_name,
+                 kernel_asts,
+                 with_python_bindings=True,
+                 wrap_wrapper_functions=False,
+                 class_definitions=[]):
         """Create a C++ module with forward and optional backward_kernels
 
         :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
@@ -125,7 +130,7 @@ class TorchModule(JinjaCppFile):
             'module_name': module_name,
             'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name,
                                                           [self.PYTHON_FUNCTION_WRAPPING_CLASS(a)
-                                                              for a in wrapper_functions])
+                                                              for a in wrapper_functions] + class_definitions)
             if with_python_bindings else ''
         }
 
-- 
GitLab