Skip to content
Snippets Groups Projects
Commit f280ee36 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add ArrayWithIndexDimensions

parent dcc48904
Branches
Tags
No related merge requests found
Pipeline #18972 failed
......@@ -9,6 +9,18 @@ class _WhatEverClass:
self.__dict__.update(kwargs)
class ArrayWithIndexDimensions:
def __init__(self, array, index_dimensions):
self.array = array
self.index_dimensions = index_dimensions
def __array__(self):
return self.array
def __getattr__(self, name):
return getattr(self.array, name)
def _torch_tensor_to_numpy_shim(tensor):
from pystencils.autodiff.backends._pytorch import torch_dtype_to_numpy
......@@ -20,6 +32,12 @@ def _torch_tensor_to_numpy_shim(tensor):
def create_field_from_array_like(field_name, maybe_array):
if isinstance(maybe_array, ArrayWithIndexDimensions):
index_dimensions = maybe_array.index_dimensions
maybe_array = maybe_array.array
else:
index_dimensions = 0
try:
import torch
except ImportError:
......@@ -30,10 +48,6 @@ def create_field_from_array_like(field_name, maybe_array):
if isinstance(maybe_array, torch.Tensor):
maybe_array = _torch_tensor_to_numpy_shim(maybe_array)
try:
index_dimensions = maybe_array.index_dimensions
except AttributeError:
index_dimensions = 0
return Field.create_from_numpy_array(field_name, maybe_array, index_dimensions)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment