diff --git a/src/pystencils_reco/resampling.py b/src/pystencils_reco/resampling.py index 0021c990ae8737b0924572da15e7dd23b3745b66..d0c8ec4f77f458d30567f3887bbc074a9e53e8eb 100644 --- a/src/pystencils_reco/resampling.py +++ b/src/pystencils_reco/resampling.py @@ -10,6 +10,7 @@ Implements common resampling operations like rotations and scalings import types from collections.abc import Iterable +from typing import Union import sympy @@ -38,7 +39,7 @@ def generic_spatial_matrix_transform(input_field, inverse_matrix @ output_field.physical_coordinates, staggered=False) assignments = AssignmentCollection({ - output_field.center(): + output_field.center: texture.at(output_coordinate) }) @@ -184,3 +185,17 @@ def downsample(input: {'field_type': pystencils.field.FieldType.CUSTOM}, assignments._create_autodiff = types.MethodType(create_autodiff, assignments) return assignments + + +@crazy +def resample_to_shape(input, + spatial_shape: Union[tuple, pystencils.Field], + ): + if hasattr(spatial_shape, 'spatial_shape'): + spatial_shape = spatial_shape.spatial_shape + + output_field = pystencils.Field.create_fixed_size( + 'output', spatial_shape + input.index_shape, input.index_dimensions, input.dtype.numpy_dtype) + output_field.coordinate_transform = sympy.DiagMatrix(sympy.Matrix([input.spatial_shape[i] / s + for i, s in enumerate(spatial_shape)])) + return resample(input, output_field) diff --git a/src/pystencils_reco/transforms.py b/src/pystencils_reco/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb7295faba5d8a364465940b2bc0b79bb512ca5 --- /dev/null +++ b/src/pystencils_reco/transforms.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2020 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" +import sympy as sp + +import pystencils + + +def extend_to_size_of_other_field(this_field: pystencils.Field, other_field: pystencils.Field): + this_field.coordinate_transform = sp.DiagMatrix(sp.Matrix([this_field.spatial_shape[i] + / other_field.spatial_shape[i] + for i in range(len(this_field.spatial_shape))]))