From f13ac79ef167dc7342bd6c4db9c5090d89861f8d Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Wed, 8 Jan 2020 14:28:17 +0100 Subject: [PATCH] Add resample_to_shape --- src/pystencils_reco/resampling.py | 17 ++++++++++++++++- src/pystencils_reco/transforms.py | 18 ++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 src/pystencils_reco/transforms.py diff --git a/src/pystencils_reco/resampling.py b/src/pystencils_reco/resampling.py index 0021c99..d0c8ec4 100644 --- a/src/pystencils_reco/resampling.py +++ b/src/pystencils_reco/resampling.py @@ -10,6 +10,7 @@ Implements common resampling operations like rotations and scalings import types from collections.abc import Iterable +from typing import Union import sympy @@ -38,7 +39,7 @@ def generic_spatial_matrix_transform(input_field, inverse_matrix @ output_field.physical_coordinates, staggered=False) assignments = AssignmentCollection({ - output_field.center(): + output_field.center: texture.at(output_coordinate) }) @@ -184,3 +185,17 @@ def downsample(input: {'field_type': pystencils.field.FieldType.CUSTOM}, assignments._create_autodiff = types.MethodType(create_autodiff, assignments) return assignments + + +@crazy +def resample_to_shape(input, + spatial_shape: Union[tuple, pystencils.Field], + ): + if hasattr(spatial_shape, 'spatial_shape'): + spatial_shape = spatial_shape.spatial_shape + + output_field = pystencils.Field.create_fixed_size( + 'output', spatial_shape + input.index_shape, input.index_dimensions, input.dtype.numpy_dtype) + output_field.coordinate_transform = sympy.DiagMatrix(sympy.Matrix([input.spatial_shape[i] / s + for i, s in enumerate(spatial_shape)])) + return resample(input, output_field) diff --git a/src/pystencils_reco/transforms.py b/src/pystencils_reco/transforms.py new file mode 100644 index 0000000..7fb7295 --- /dev/null +++ b/src/pystencils_reco/transforms.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2020 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" +import sympy as sp + +import pystencils + + +def extend_to_size_of_other_field(this_field: pystencils.Field, other_field: pystencils.Field): + this_field.coordinate_transform = sp.DiagMatrix(sp.Matrix([this_field.spatial_shape[i] + / other_field.spatial_shape[i] + for i in range(len(this_field.spatial_shape))])) -- GitLab