From bdd0e1e637e7592960e9f92d4ba6e866893b68be Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 28 Oct 2019 17:30:08 +0100 Subject: [PATCH] Allow sole WrapperFunctions in TorchModule --- src/pystencils_autodiff/backends/astnodes.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index bb24dc5..9c968cc 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -55,7 +55,7 @@ class PybindArrayDestructuring(DestructuringBindingsForFieldClass): CLASS_TO_MEMBER_DICT = { FieldPointerSymbol: "mutable_data()", FieldShapeSymbol: "shape({dim})", - FieldStrideSymbol: "strides({dim})" + FieldStrideSymbol: "strides({dim}) / sizeof({dtype})" } CLASS_NAME_TEMPLATE = "pybind11::array_t<{dtype}>" @@ -77,7 +77,10 @@ class TorchModule(JinjaCppFile): """ if not isinstance(kernel_asts, Iterable): kernel_asts = [kernel_asts] - wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts] + wrapper_functions = [self.generate_wrapper_function(k) + if not isinstance(k, WrapperFunction) + else k for k in kernel_asts] + kernel_asts = [k for k in kernel_asts if not isinstance(k, WrapperFunction)] self.module_name = module_name ast_dict = { -- GitLab