From 59b6c5c4719d4de3f72003cbca6ca83f0d1161f1 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 10 Jan 2020 17:35:42 +0100 Subject: [PATCH] Form backward_assignments with spatial derivatives --- src/pystencils_reco/resampling.py | 7 +++++-- tests/test_superresolution.py | 21 +++++++++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/pystencils_reco/resampling.py b/src/pystencils_reco/resampling.py index d0c8ec4..86d58fa 100644 --- a/src/pystencils_reco/resampling.py +++ b/src/pystencils_reco/resampling.py @@ -121,7 +121,8 @@ def resample(input_field, output_field, interpolation_mode='linear'): def translate(input_field: pystencils.Field, output_field: pystencils.Field, translation, - interpolation_mode='linear'): + interpolation_mode='linear', + allow_spatial_derivatives=True): def create_autodiff(self, constant_fields=None, **kwargs): backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation) @@ -136,7 +137,9 @@ def translate(input_field: pystencils.Field, output_field.center: input_field.interpolated_access( input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode) }) - assignments._create_autodiff = types.MethodType(create_autodiff, assignments) + + if not allow_spatial_derivatives: + assignments._create_autodiff = types.MethodType(create_autodiff, assignments) return assignments diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py index 50f0ca9..648fb96 100644 --- a/tests/test_superresolution.py +++ b/tests/test_superresolution.py @@ -93,7 +93,7 @@ def test_polar_transform2(): class PolarTransform(sympy.Function): def eval(args): return sympy.Matrix( - (args.norm(), sympy.atan2(args[1]-x.shape[1]/2, args[0]-x.shape[0]/2) / sympy.pi * x.shape[1]/2)) + (args.norm(), sympy.atan2(args[1] - x.shape[1] / 2, args[0] - x.shape[0] / 2) / sympy.pi * x.shape[1] / 2)) x.set_coordinate_origin_to_field_center() y.coordinate_transform = PolarTransform @@ -117,11 +117,11 @@ def test_polar_inverted_transform(): class PolarTransform(sympy.Function): def eval(args): return sympy.Matrix( - (args.norm(), sympy.atan2(args[1]-x.shape[1]/2, args[0]-x.shape[0]/2) / sympy.pi * x.shape[1]/2)) + (args.norm(), sympy.atan2(args[1] - x.shape[1] / 2, args[0] - x.shape[0] / 2) / sympy.pi * x.shape[1] / 2)) def inv(): - return lambda l: (sympy.Matrix((sympy.cos(l[1] * sympy.pi / x.shape[1]*2) * l[0], - sympy.sin(l[1] * sympy.pi / x.shape[1]*2) * l[0])) + return lambda l: (sympy.Matrix((sympy.cos(l[1] * sympy.pi / x.shape[1] * 2) * l[0], + sympy.sin(l[1] * sympy.pi / x.shape[1] * 2) * l[0])) + sympy.Matrix(x.shape) * 0.5) lenna_file = join(dirname(__file__), "test_data", "lenna.png") @@ -251,3 +251,16 @@ def test_motion_model2(): # while True: # sleep(100) + + +def test_spatial_derivative(): + x, y = pystencils.fields('x, y: float32[2d]') + tx, ty = pystencils.fields('t_x, t_y: float32[2d]') + + assignments = pystencils.AssignmentCollection({ + y.center: x.interpolated_access((tx.center, 2 * ty.center)) + }) + + backward_assignments = pystencils.autodiff.create_backward_assignments(assignments) + + print("backward_assignments: " + str(backward_assignments)) -- GitLab