From f1caa29c671f39253ebbb35e9156fa0e2ef4eb6f Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Tue, 4 Feb 2020 16:58:54 +0100 Subject: [PATCH] Use automatic calculation of adjoint in resampling --- src/pystencils_reco/resampling.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/pystencils_reco/resampling.py b/src/pystencils_reco/resampling.py index 86d58fa..c9660c8 100644 --- a/src/pystencils_reco/resampling.py +++ b/src/pystencils_reco/resampling.py @@ -56,7 +56,7 @@ def generic_spatial_matrix_transform(input_field, self._autodiff = pystencils.autodiff.AutoDiffOp( assignments, "", backward_assignments=backward_assignments, **kwargs) - assignments._create_autodiff = types.MethodType(create_autodiff, assignments) + # assignments._create_autodiff = types.MethodType(create_autodiff, assignments) return assignments @@ -86,6 +86,9 @@ def rotation_transform(input_field, rotation_angle, rotation_axis=None, interpolation_mode='linear'): + if isinstance(rotation_angle, pystencils.Field): + rotation_angle = rotation_angle.center + if input_field.spatial_dimensions == 3: assert rotation_axis is not None, "You must specify a rotation_axis for 3d rotations!" transform_matrix = getattr(sympy, 'rot_axis%i' % (rotation_axis + 1))(rotation_angle) @@ -124,10 +127,10 @@ def translate(input_field: pystencils.Field, interpolation_mode='linear', allow_spatial_derivatives=True): - def create_autodiff(self, constant_fields=None, **kwargs): - backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation) - self._autodiff = pystencils.autodiff.AutoDiffOp( - assignments, "", backward_assignments=backward_assignments, **kwargs) + # def create_autodiff(self, constant_fields=None, **kwargs): + # backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation) + # self._autodiff = pystencils.autodiff.AutoDiffOp( + # assignments, "", backward_assignments=backward_assignments, **kwargs) if isinstance(translation, pystencils.Field): translation = translation.center_vector @@ -138,8 +141,8 @@ def translate(input_field: pystencils.Field, input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode) }) - if not allow_spatial_derivatives: - assignments._create_autodiff = types.MethodType(create_autodiff, assignments) + # if not allow_spatial_derivatives: + # assignments._create_autodiff = types.MethodType(create_autodiff, assignments) return assignments -- GitLab