From 59b6c5c4719d4de3f72003cbca6ca83f0d1161f1 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 10 Jan 2020 17:35:42 +0100
Subject: [PATCH] Form backward_assignments with spatial derivatives

---
 src/pystencils_reco/resampling.py |  7 +++++--
 tests/test_superresolution.py     | 21 +++++++++++++++++----
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/src/pystencils_reco/resampling.py b/src/pystencils_reco/resampling.py
index d0c8ec4..86d58fa 100644
--- a/src/pystencils_reco/resampling.py
+++ b/src/pystencils_reco/resampling.py
@@ -121,7 +121,8 @@ def resample(input_field, output_field, interpolation_mode='linear'):
 def translate(input_field: pystencils.Field,
               output_field: pystencils.Field,
               translation,
-              interpolation_mode='linear'):
+              interpolation_mode='linear',
+              allow_spatial_derivatives=True):
 
     def create_autodiff(self, constant_fields=None, **kwargs):
         backward_assignments = translate(AdjointField(output_field), AdjointField(input_field), -translation)
@@ -136,7 +137,9 @@ def translate(input_field: pystencils.Field,
             output_field.center: input_field.interpolated_access(
                 input_field.physical_to_index(output_field.physical_coordinates - translation), interpolation_mode)
         })
-    assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
+
+    if not allow_spatial_derivatives:
+        assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
     return assignments
 
 
diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py
index 50f0ca9..648fb96 100644
--- a/tests/test_superresolution.py
+++ b/tests/test_superresolution.py
@@ -93,7 +93,7 @@ def test_polar_transform2():
     class PolarTransform(sympy.Function):
         def eval(args):
             return sympy.Matrix(
-                (args.norm(), sympy.atan2(args[1]-x.shape[1]/2, args[0]-x.shape[0]/2) / sympy.pi * x.shape[1]/2))
+                (args.norm(), sympy.atan2(args[1] - x.shape[1] / 2, args[0] - x.shape[0] / 2) / sympy.pi * x.shape[1] / 2))
 
     x.set_coordinate_origin_to_field_center()
     y.coordinate_transform = PolarTransform
@@ -117,11 +117,11 @@ def test_polar_inverted_transform():
     class PolarTransform(sympy.Function):
         def eval(args):
             return sympy.Matrix(
-                (args.norm(), sympy.atan2(args[1]-x.shape[1]/2, args[0]-x.shape[0]/2) / sympy.pi * x.shape[1]/2))
+                (args.norm(), sympy.atan2(args[1] - x.shape[1] / 2, args[0] - x.shape[0] / 2) / sympy.pi * x.shape[1] / 2))
 
         def inv():
-            return lambda l: (sympy.Matrix((sympy.cos(l[1] * sympy.pi / x.shape[1]*2) * l[0],
-                                            sympy.sin(l[1] * sympy.pi / x.shape[1]*2) * l[0]))
+            return lambda l: (sympy.Matrix((sympy.cos(l[1] * sympy.pi / x.shape[1] * 2) * l[0],
+                                            sympy.sin(l[1] * sympy.pi / x.shape[1] * 2) * l[0]))
                               + sympy.Matrix(x.shape) * 0.5)
 
     lenna_file = join(dirname(__file__), "test_data", "lenna.png")
@@ -251,3 +251,16 @@ def test_motion_model2():
 
     # while True:
     # sleep(100)
+
+
+def test_spatial_derivative():
+    x, y = pystencils.fields('x, y:  float32[2d]')
+    tx, ty = pystencils.fields('t_x, t_y: float32[2d]')
+
+    assignments = pystencils.AssignmentCollection({
+        y.center: x.interpolated_access((tx.center, 2 * ty.center))
+    })
+
+    backward_assignments = pystencils.autodiff.create_backward_assignments(assignments)
+
+    print("backward_assignments: " + str(backward_assignments))
-- 
GitLab