diff --git a/src/pystencils_autodiff/field_tensor_conversion.py b/src/pystencils_autodiff/field_tensor_conversion.py index f819b349dad6bf7533b32a80b199f0567f632d64..1e411915e370a0abd56ea7efd12f9bf251588383 100644 --- a/src/pystencils_autodiff/field_tensor_conversion.py +++ b/src/pystencils_autodiff/field_tensor_conversion.py @@ -19,7 +19,7 @@ def _torch_tensor_to_numpy_shim(tensor): return fake_array -def _create_field_from_array_like(field_name, maybe_array): +def create_field_from_array_like(field_name, maybe_array): try: import torch except ImportError: @@ -30,13 +30,17 @@ def _create_field_from_array_like(field_name, maybe_array): if isinstance(maybe_array, torch.Tensor): maybe_array = _torch_tensor_to_numpy_shim(maybe_array) - return Field.create_from_numpy_array(field_name, maybe_array) + try: + index_dimensions = maybe_array.index_dimensions + except AttributeError: + index_dimensions = 0 + return Field.create_from_numpy_array(field_name, maybe_array, index_dimensions) def coerce_to_field(field_name, array_like): if isinstance(array_like, Field): return array_like.new_field_with_different_name(field_name, array_like) - return _create_field_from_array_like(field_name, array_like) + return create_field_from_array_like(field_name, array_like) def is_array_like(a):