diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index a040b060adbe957e0e092e2ebc0617a47b63a576..4ca0952c804a7bc8b6039c451f65170e1f94228f 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -122,6 +122,7 @@ class TorchModule(JinjaCppFile): ast_dict = { 'kernels': kernel_asts, 'kernel_wrappers': wrapper_functions, + 'module_name': module_name, 'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name, [self.PYTHON_FUNCTION_WRAPPING_CLASS(a) for a in wrapper_functions])