diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index bb24dc510799309970b484692971b693ba50096c..9c968ccbb1f59e2ae97acfccd9c7baf501ab68ae 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -55,7 +55,7 @@ class PybindArrayDestructuring(DestructuringBindingsForFieldClass):
     CLASS_TO_MEMBER_DICT = {
         FieldPointerSymbol: "mutable_data()",
         FieldShapeSymbol: "shape({dim})",
-        FieldStrideSymbol: "strides({dim})"
+        FieldStrideSymbol: "strides({dim}) / sizeof({dtype})"
     }
 
     CLASS_NAME_TEMPLATE = "pybind11::array_t<{dtype}>"
@@ -77,7 +77,10 @@ class TorchModule(JinjaCppFile):
         """
         if not isinstance(kernel_asts, Iterable):
             kernel_asts = [kernel_asts]
-        wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts]
+        wrapper_functions = [self.generate_wrapper_function(k)
+                             if not isinstance(k, WrapperFunction)
+                             else k for k in kernel_asts]
+        kernel_asts = [k for k in kernel_asts if not isinstance(k, WrapperFunction)]
         self.module_name = module_name
 
         ast_dict = {