diff --git a/src/pystencils_autodiff/field_tensor_conversion.py b/src/pystencils_autodiff/field_tensor_conversion.py index ae5fe58973a4c21e1efe6f686fab84bab9c6a43b..9e0b054bda31259e37266433c73d27218c0b11c9 100644 --- a/src/pystencils_autodiff/field_tensor_conversion.py +++ b/src/pystencils_autodiff/field_tensor_conversion.py @@ -10,14 +10,30 @@ class _WhatEverClass: self.__dict__.update(kwargs) -class ArrayWithIndexDimensions: +class ArrayWrapper: def __init__(self, array, - index_dimensions, - field_type=FieldType.GENERIC): + index_dimensions=0, + field_type=FieldType.GENERIC, + coordinate_transform=None, + spacing=None, + origin=None, + coordinate_origin=None): self.array = array self.index_dimensions = index_dimensions self.field_type = field_type + if spacing: + coordinate_transform = sympy.diag(*spacing) + if origin: + origin = sympy.Matrix(origin) + spacing = spacing or origin / origin + coordinate_transform = sympy.diag(*spacing) + coordinate_origin = origin / spacing + + if coordinate_transform: + self.coordinate_transform = coordinate_transform + if coordinate_origin: + self.coordinate_origin = coordinate_origin def __array__(self): return self.array @@ -40,7 +56,7 @@ def create_field_from_array_like(field_name, maybe_array, annotations=None): if annotations and isinstance(annotations, dict): index_dimensions = annotations.get('index_dimensions', 0) field_type = annotations.get('field_type', FieldType.GENERIC) - elif isinstance(maybe_array, ArrayWithIndexDimensions): + elif isinstance(maybe_array, ArrayWrapper): index_dimensions = maybe_array.index_dimensions field_type = maybe_array.field_type maybe_array = maybe_array.array @@ -65,6 +81,10 @@ def create_field_from_array_like(field_name, maybe_array, annotations=None): field = Field.create_from_numpy_array(field_name, maybe_array, index_dimensions) field.field_type = field_type + if hasattr(maybe_array, 'coordinate_transform'): + field.coordinate_transform = maybe_array.coordinate_transform + if hasattr(maybe_array, 'coordinate_origin'): + field.coordinate_origin = maybe_array.coordinate_origin return field