diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index b8487f1395b8617eced9425a8d9302119187c179..a040b060adbe957e0e092e2ebc0617a47b63a576 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -55,7 +55,7 @@ class Header(JinjaCppFile): class TorchTensorDestructuring(DestructuringBindingsForFieldClass): CLASS_TO_MEMBER_DICT = { - FieldPointerSymbol: "data<{dtype}>()", + FieldPointerSymbol: "data_ptr<{dtype}>()", FieldShapeSymbol: "size({dim})", FieldStrideSymbol: "strides()[{dim}]" }