diff --git a/src/pystencils_reco/resampling.py b/src/pystencils_reco/resampling.py index 86d58fab6f1f577ce73b3dd4d6f9b5ff88648405..c9660c8271b397f04f5015667f41d002cd4ff677 100644 --- a/src/pystencils_reco/resampling.py +++ b/src/pystencils_reco/resampling.py @@ -56,7 +56,7 @@ def generic_spatial_matrix_transform(input_field, self._autodiff = pystencils.autodiff.AutoDiffOp( assignments, "", backward_assignments=backward_assignments, **kwargs) - assignments._create_autodiff = types.MethodType(create_autodiff, assignments) + # assignments._create_autodiff = types.MethodType(create_autodiff, assignments) return assignments @@ -86,6 +86,9 @@ def rotation_transform(input_field, rotation_angle, rotation_axis=None, interpolation_mode='linear'): + if isinstance(rotation_angle, pystencils.Field): + rotation_angle = rotation_angle.center + if input_field.spatial_dimensions == 3: assert rotation_axis is not None, "You must specify a rotation_axis for 3d rotations!" transform_matrix = getattr(sympy, 'rot_axis%i' % (rotation_axis + 1))(rotation_angle) @@ -124,10 +127,10 @@ def translate(input_field: pystencils.Field, interpolation_mode='linear', allow_spatial_derivatives=True): - def create_autodiff(self, constant_fields=None, **kwargs): - backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation) - self._autodiff = pystencils.autodiff.AutoDiffOp( - assignments, "", backward_assignments=backward_assignments, **kwargs) + # def create_autodiff(self, constant_fields=None, **kwargs): + # backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation) + # self._autodiff = pystencils.autodiff.AutoDiffOp( + # assignments, "", backward_assignments=backward_assignments, **kwargs) if isinstance(translation, pystencils.Field): translation = translation.center_vector @@ -138,8 +141,8 @@ def translate(input_field: pystencils.Field, input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode) }) - if not allow_spatial_derivatives: - assignments._create_autodiff = types.MethodType(create_autodiff, assignments) + # if not allow_spatial_derivatives: + # assignments._create_autodiff = types.MethodType(create_autodiff, assignments) return assignments