diff --git a/src/pystencils_autodiff/field_tensor_conversion.py b/src/pystencils_autodiff/field_tensor_conversion.py index 1e411915e370a0abd56ea7efd12f9bf251588383..edc4947b846c34f2b5ab280374a82cf157cdac25 100644 --- a/src/pystencils_autodiff/field_tensor_conversion.py +++ b/src/pystencils_autodiff/field_tensor_conversion.py @@ -9,6 +9,18 @@ class _WhatEverClass: self.__dict__.update(kwargs) +class ArrayWithIndexDimensions: + def __init__(self, array, index_dimensions): + self.array = array + self.index_dimensions = index_dimensions + + def __array__(self): + return self.array + + def __getattr__(self, name): + return getattr(self.array, name) + + def _torch_tensor_to_numpy_shim(tensor): from pystencils.autodiff.backends._pytorch import torch_dtype_to_numpy @@ -20,6 +32,12 @@ def _torch_tensor_to_numpy_shim(tensor): def create_field_from_array_like(field_name, maybe_array): + if isinstance(maybe_array, ArrayWithIndexDimensions): + index_dimensions = maybe_array.index_dimensions + maybe_array = maybe_array.array + else: + index_dimensions = 0 + try: import torch except ImportError: @@ -30,10 +48,6 @@ def create_field_from_array_like(field_name, maybe_array): if isinstance(maybe_array, torch.Tensor): maybe_array = _torch_tensor_to_numpy_shim(maybe_array) - try: - index_dimensions = maybe_array.index_dimensions - except AttributeError: - index_dimensions = 0 return Field.create_from_numpy_array(field_name, maybe_array, index_dimensions)