From 39af247c83948bb2e734f9a8ebd93678ba81c821 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Thu, 17 Oct 2019 16:37:05 +0200 Subject: [PATCH] Support index_dimensions In create_field_from_array_like --- src/pystencils_autodiff/field_tensor_conversion.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/pystencils_autodiff/field_tensor_conversion.py b/src/pystencils_autodiff/field_tensor_conversion.py index f819b34..1e41191 100644 --- a/src/pystencils_autodiff/field_tensor_conversion.py +++ b/src/pystencils_autodiff/field_tensor_conversion.py @@ -19,7 +19,7 @@ def _torch_tensor_to_numpy_shim(tensor): return fake_array -def _create_field_from_array_like(field_name, maybe_array): +def create_field_from_array_like(field_name, maybe_array): try: import torch except ImportError: @@ -30,13 +30,17 @@ def _create_field_from_array_like(field_name, maybe_array): if isinstance(maybe_array, torch.Tensor): maybe_array = _torch_tensor_to_numpy_shim(maybe_array) - return Field.create_from_numpy_array(field_name, maybe_array) + try: + index_dimensions = maybe_array.index_dimensions + except AttributeError: + index_dimensions = 0 + return Field.create_from_numpy_array(field_name, maybe_array, index_dimensions) def coerce_to_field(field_name, array_like): if isinstance(array_like, Field): return array_like.new_field_with_different_name(field_name, array_like) - return _create_field_from_array_like(field_name, array_like) + return create_field_from_array_like(field_name, array_like) def is_array_like(a): -- GitLab