diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index bb24dc510799309970b484692971b693ba50096c..9c968ccbb1f59e2ae97acfccd9c7baf501ab68ae 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -55,7 +55,7 @@ class PybindArrayDestructuring(DestructuringBindingsForFieldClass): CLASS_TO_MEMBER_DICT = { FieldPointerSymbol: "mutable_data()", FieldShapeSymbol: "shape({dim})", - FieldStrideSymbol: "strides({dim})" + FieldStrideSymbol: "strides({dim}) / sizeof({dtype})" } CLASS_NAME_TEMPLATE = "pybind11::array_t<{dtype}>" @@ -77,7 +77,10 @@ class TorchModule(JinjaCppFile): """ if not isinstance(kernel_asts, Iterable): kernel_asts = [kernel_asts] - wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts] + wrapper_functions = [self.generate_wrapper_function(k) + if not isinstance(k, WrapperFunction) + else k for k in kernel_asts] + kernel_asts = [k for k in kernel_asts if not isinstance(k, WrapperFunction)] self.module_name = module_name ast_dict = {