diff --git a/src/pystencils_autodiff/field_tensor_conversion.py b/src/pystencils_autodiff/field_tensor_conversion.py index af3899674f08284bd1b45a3dc3fec1b08990b79c..ca0ae54df5b39e2dd9d1340fa92ed7a73a7c7113 100644 --- a/src/pystencils_autodiff/field_tensor_conversion.py +++ b/src/pystencils_autodiff/field_tensor_conversion.py @@ -30,7 +30,7 @@ def _torch_tensor_to_numpy_shim(tensor): from pystencils.autodiff.backends._pytorch import torch_dtype_to_numpy fake_array = _WhatEverClass( - strides=[tensor.stride(i) for i in range(len(tensor.shape))], + strides=[tensor.stride(i) * tensor.storage().element_size() for i in range(len(tensor.shape))], shape=tensor.shape, dtype=torch_dtype_to_numpy(tensor.dtype)) return fake_array