diff --git a/src/pystencils_reco/resampling.py b/src/pystencils_reco/resampling.py index d0c8ec4f77f458d30567f3887bbc074a9e53e8eb..86d58fab6f1f577ce73b3dd4d6f9b5ff88648405 100644 --- a/src/pystencils_reco/resampling.py +++ b/src/pystencils_reco/resampling.py @@ -121,7 +121,8 @@ def resample(input_field, output_field, interpolation_mode='linear'): def translate(input_field: pystencils.Field, output_field: pystencils.Field, translation, - interpolation_mode='linear'): + interpolation_mode='linear', + allow_spatial_derivatives=True): def create_autodiff(self, constant_fields=None, **kwargs): backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation) @@ -136,7 +137,9 @@ def translate(input_field: pystencils.Field, output_field.center: input_field.interpolated_access( input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode) }) - assignments._create_autodiff = types.MethodType(create_autodiff, assignments) + + if not allow_spatial_derivatives: + assignments._create_autodiff = types.MethodType(create_autodiff, assignments) return assignments diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py index 50f0ca9721fbd82ae5e199a77ecd47e565fa733b..648fb96d3121a951e3267fdbca780f97d835ba6d 100644 --- a/tests/test_superresolution.py +++ b/tests/test_superresolution.py @@ -93,7 +93,7 @@ def test_polar_transform2(): class PolarTransform(sympy.Function): def eval(args): return sympy.Matrix( - (args.norm(), sympy.atan2(args[1]-x.shape[1]/2, args[0]-x.shape[0]/2) / sympy.pi * x.shape[1]/2)) + (args.norm(), sympy.atan2(args[1] - x.shape[1] / 2, args[0] - x.shape[0] / 2) / sympy.pi * x.shape[1] / 2)) x.set_coordinate_origin_to_field_center() y.coordinate_transform = PolarTransform @@ -117,11 +117,11 @@ def test_polar_inverted_transform(): class PolarTransform(sympy.Function): def eval(args): return sympy.Matrix( - (args.norm(), sympy.atan2(args[1]-x.shape[1]/2, args[0]-x.shape[0]/2) / sympy.pi * x.shape[1]/2)) + (args.norm(), sympy.atan2(args[1] - x.shape[1] / 2, args[0] - x.shape[0] / 2) / sympy.pi * x.shape[1] / 2)) def inv(): - return lambda l: (sympy.Matrix((sympy.cos(l[1] * sympy.pi / x.shape[1]*2) * l[0], - sympy.sin(l[1] * sympy.pi / x.shape[1]*2) * l[0])) + return lambda l: (sympy.Matrix((sympy.cos(l[1] * sympy.pi / x.shape[1] * 2) * l[0], + sympy.sin(l[1] * sympy.pi / x.shape[1] * 2) * l[0])) + sympy.Matrix(x.shape) * 0.5) lenna_file = join(dirname(__file__), "test_data", "lenna.png") @@ -251,3 +251,16 @@ def test_motion_model2(): # while True: # sleep(100) + + +def test_spatial_derivative(): + x, y = pystencils.fields('x, y: float32[2d]') + tx, ty = pystencils.fields('t_x, t_y: float32[2d]') + + assignments = pystencils.AssignmentCollection({ + y.center: x.interpolated_access((tx.center, 2 * ty.center)) + }) + + backward_assignments = pystencils.autodiff.create_backward_assignments(assignments) + + print("backward_assignments: " + str(backward_assignments))