diff --git a/src/pystencils_autodiff/field_tensor_conversion.py b/src/pystencils_autodiff/field_tensor_conversion.py index edc4947b846c34f2b5ab280374a82cf157cdac25..af3899674f08284bd1b45a3dc3fec1b08990b79c 100644 --- a/src/pystencils_autodiff/field_tensor_conversion.py +++ b/src/pystencils_autodiff/field_tensor_conversion.py @@ -2,6 +2,7 @@ import numpy as np import sympy from pystencils import Field +from pystencils.field import FieldType class _WhatEverClass: @@ -10,9 +11,13 @@ class _WhatEverClass: class ArrayWithIndexDimensions: - def __init__(self, array, index_dimensions): + def __init__(self, + array, + index_dimensions, + field_type=FieldType.GENERIC): self.array = array self.index_dimensions = index_dimensions + self.field_type = field_type def __array__(self): return self.array @@ -34,9 +39,11 @@ def _torch_tensor_to_numpy_shim(tensor): def create_field_from_array_like(field_name, maybe_array): if isinstance(maybe_array, ArrayWithIndexDimensions): index_dimensions = maybe_array.index_dimensions + field_type = maybe_array.field_type maybe_array = maybe_array.array else: index_dimensions = 0 + field_type = FieldType.GENERIC try: import torch @@ -48,7 +55,9 @@ def create_field_from_array_like(field_name, maybe_array): if isinstance(maybe_array, torch.Tensor): maybe_array = _torch_tensor_to_numpy_shim(maybe_array) - return Field.create_from_numpy_array(field_name, maybe_array, index_dimensions) + field = Field.create_from_numpy_array(field_name, maybe_array, index_dimensions) + field.field_type = field_type + return field def coerce_to_field(field_name, array_like):