From f280ee365838769a0e83bf374c5e45beafad32cb Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Thu, 17 Oct 2019 17:24:09 +0200 Subject: [PATCH] Add ArrayWithIndexDimensions --- .../field_tensor_conversion.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/pystencils_autodiff/field_tensor_conversion.py b/src/pystencils_autodiff/field_tensor_conversion.py index 1e41191..edc4947 100644 --- a/src/pystencils_autodiff/field_tensor_conversion.py +++ b/src/pystencils_autodiff/field_tensor_conversion.py @@ -9,6 +9,18 @@ class _WhatEverClass: self.__dict__.update(kwargs) +class ArrayWithIndexDimensions: + def __init__(self, array, index_dimensions): + self.array = array + self.index_dimensions = index_dimensions + + def __array__(self): + return self.array + + def __getattr__(self, name): + return getattr(self.array, name) + + def _torch_tensor_to_numpy_shim(tensor): from pystencils.autodiff.backends._pytorch import torch_dtype_to_numpy @@ -20,6 +32,12 @@ def _torch_tensor_to_numpy_shim(tensor): def create_field_from_array_like(field_name, maybe_array): + if isinstance(maybe_array, ArrayWithIndexDimensions): + index_dimensions = maybe_array.index_dimensions + maybe_array = maybe_array.array + else: + index_dimensions = 0 + try: import torch except ImportError: @@ -30,10 +48,6 @@ def create_field_from_array_like(field_name, maybe_array): if isinstance(maybe_array, torch.Tensor): maybe_array = _torch_tensor_to_numpy_shim(maybe_array) - try: - index_dimensions = maybe_array.index_dimensions - except AttributeError: - index_dimensions = 0 return Field.create_from_numpy_array(field_name, maybe_array, index_dimensions) -- GitLab