Skip to content
Snippets Groups Projects
Commit 06103c82 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Bugfix: torch.stride is strides in element_size()

parent d78561a2
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,7 @@ def _torch_tensor_to_numpy_shim(tensor):
from pystencils.autodiff.backends._pytorch import torch_dtype_to_numpy
fake_array = _WhatEverClass(
strides=[tensor.stride(i) for i in range(len(tensor.shape))],
strides=[tensor.stride(i) * tensor.storage().element_size() for i in range(len(tensor.shape))],
shape=tensor.shape,
dtype=torch_dtype_to_numpy(tensor.dtype))
return fake_array
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment