From f280ee365838769a0e83bf374c5e45beafad32cb Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 17 Oct 2019 17:24:09 +0200
Subject: [PATCH] Add ArrayWithIndexDimensions

---
 .../field_tensor_conversion.py                | 22 +++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/src/pystencils_autodiff/field_tensor_conversion.py b/src/pystencils_autodiff/field_tensor_conversion.py
index 1e41191..edc4947 100644
--- a/src/pystencils_autodiff/field_tensor_conversion.py
+++ b/src/pystencils_autodiff/field_tensor_conversion.py
@@ -9,6 +9,18 @@ class _WhatEverClass:
         self.__dict__.update(kwargs)
 
 
+class ArrayWithIndexDimensions:
+    def __init__(self, array, index_dimensions):
+        self.array = array
+        self.index_dimensions = index_dimensions
+
+    def __array__(self):
+        return self.array
+
+    def __getattr__(self, name):
+        return getattr(self.array, name)
+
+
 def _torch_tensor_to_numpy_shim(tensor):
 
     from pystencils.autodiff.backends._pytorch import torch_dtype_to_numpy
@@ -20,6 +32,12 @@ def _torch_tensor_to_numpy_shim(tensor):
 
 
 def create_field_from_array_like(field_name, maybe_array):
+    if isinstance(maybe_array, ArrayWithIndexDimensions):
+        index_dimensions = maybe_array.index_dimensions
+        maybe_array = maybe_array.array
+    else:
+        index_dimensions = 0
+
     try:
         import torch
     except ImportError:
@@ -30,10 +48,6 @@ def create_field_from_array_like(field_name, maybe_array):
         if isinstance(maybe_array, torch.Tensor):
             maybe_array = _torch_tensor_to_numpy_shim(maybe_array)
 
-    try:
-        index_dimensions = maybe_array.index_dimensions
-    except AttributeError:
-        index_dimensions = 0
     return Field.create_from_numpy_array(field_name, maybe_array, index_dimensions)
 
 
-- 
GitLab