From bdd0e1e637e7592960e9f92d4ba6e866893b68be Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 28 Oct 2019 17:30:08 +0100
Subject: [PATCH] Allow sole WrapperFunctions in TorchModule

---
 src/pystencils_autodiff/backends/astnodes.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index bb24dc5..9c968cc 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -55,7 +55,7 @@ class PybindArrayDestructuring(DestructuringBindingsForFieldClass):
     CLASS_TO_MEMBER_DICT = {
         FieldPointerSymbol: "mutable_data()",
         FieldShapeSymbol: "shape({dim})",
-        FieldStrideSymbol: "strides({dim})"
+        FieldStrideSymbol: "strides({dim}) / sizeof({dtype})"
     }
 
     CLASS_NAME_TEMPLATE = "pybind11::array_t<{dtype}>"
@@ -77,7 +77,10 @@ class TorchModule(JinjaCppFile):
         """
         if not isinstance(kernel_asts, Iterable):
             kernel_asts = [kernel_asts]
-        wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts]
+        wrapper_functions = [self.generate_wrapper_function(k)
+                             if not isinstance(k, WrapperFunction)
+                             else k for k in kernel_asts]
+        kernel_asts = [k for k in kernel_asts if not isinstance(k, WrapperFunction)]
         self.module_name = module_name
 
         ast_dict = {
-- 
GitLab