From af7bfff213b1ed152186440af5ebc7950e7d7035 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 14 Jan 2020 18:41:54 +0100
Subject: [PATCH] Refactor interpolation

---
 src/pystencils_reco/_assignment_collection.py |  3 +
 tests/test_superresolution.py                 | 83 ++++++++++++++++++-
 2 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_reco/_assignment_collection.py b/src/pystencils_reco/_assignment_collection.py
index 82ff7c1..e236e03 100644
--- a/src/pystencils_reco/_assignment_collection.py
+++ b/src/pystencils_reco/_assignment_collection.py
@@ -181,6 +181,9 @@ class AssignmentCollection(pystencils.AssignmentCollection):
                                 if hasattr(t, 'requires_grad') and not t.requires_grad]
         constant_fields = {f for f in self.free_fields if f.name in constant_field_names}
 
+        for n in [f for f, t in kwargs.items() if hasattr(t, 'requires_grad')]:
+            kwargs.pop(n)
+
         if not self._autodiff:
             if hasattr(self, '_create_autodiff'):
                 self._create_autodiff(constant_fields, **kwargs)
diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py
index 473c58c..0967514 100644
--- a/tests/test_superresolution.py
+++ b/tests/test_superresolution.py
@@ -259,11 +259,29 @@ def test_spatial_derivative():
     tx, ty = pystencils.fields('t_x, t_y: float32[2d]')
 
     assignments = pystencils.AssignmentCollection({
-        y.center: x.interpolated_access((tx.center, 2 * ty.center))
+        y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_))
     })
 
     backward_assignments = pystencils.autodiff.create_backward_assignments(assignments)
 
+    print("assignments: " + str(assignments))
+    print("backward_assignments: " + str(backward_assignments))
+
+
+def test_spatial_derivative2():
+    import pystencils.interpolation_astnodes
+    x, y = pystencils.fields('x, y:  float32[2d]')
+    tx, ty = pystencils.fields('t_x, t_y: float32[2d]')
+
+    assignments = pystencils.AssignmentCollection({
+        y.center: x.interpolated_access((tx.center + pystencils.x_, ty.center + 2 * pystencils.y_))
+    })
+
+    backward_assignments = pystencils.autodiff.create_backward_assignments(assignments)
+
+    assert backward_assignments.atoms(pystencils.interpolation_astnodes.DiffInterpolatorAccess)
+
+    print("assignments: " + str(assignments))
     print("backward_assignments: " + str(backward_assignments))
 
 
@@ -282,9 +300,70 @@ def test_get_shift():
     dh.all_to_gpu()
 
     kernel = pystencils_reco.AssignmentCollection({
-        y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_))
+        y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center +
+                                         pystencils.y_), interpolation_mode='cubic_spline')
     }).create_pytorch_op()().forward
 
     dh.run_kernel(kernel)
 
     pyconrad.imshow(dh.gpu_arrays)
+
+
+def test_get_shift_tensors():
+    from pystencils_autodiff.framework_integration.datahandling import PyTorchDataHandling
+    import torch
+
+    lenna_file = join(dirname(__file__), "test_data", "lenna.png")
+    lenna = skimage.io.imread(lenna_file, as_gray=True).astype(np.float32)
+
+    dh = PyTorchDataHandling(lenna.shape)
+    x, y, tx, ty = dh.add_arrays('xw, yw, txw, tyw')
+
+    dh.cpu_arrays['xw'] = lenna
+    dh.cpu_arrays['txw'][...] = 0.7
+    dh.cpu_arrays['tyw'][...] = -0.7
+    dh.all_to_gpu()
+
+    kernel = pystencils_reco.AssignmentCollection({
+        y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_),
+                                        interpolation_mode='cubic_spline')
+    }).create_pytorch_op()().call
+
+    dh.run_kernel(kernel)
+    y_array = dh.gpu_arrays['yw']
+
+    dh = PyTorchDataHandling(lenna.shape)
+    x, y, tx, ty = dh.add_arrays('x, y, tx, ty')
+    dh.cpu_arrays['tx'] = torch.zeros(lenna.shape, requires_grad=True)
+    dh.cpu_arrays['ty'] = torch.zeros(lenna.shape, requires_grad=True)
+    dh.cpu_arrays['x'] = lenna
+    dh.all_to_gpu()
+
+    kernel = pystencils_reco.AssignmentCollection({
+        y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center +
+                                         pystencils.y_), interpolation_mode='cubic_spline')
+    }).create_pytorch_op(**dh.gpu_arrays)
+
+    print(kernel.ast)
+    kernel = kernel().call
+
+    learning_rate = 1e-4
+    params = (dh.cpu_arrays['tx'], dh.cpu_arrays['ty'])
+    # assert all([p.is_leaf for p in params])
+    optimizer = torch.optim.Adam(params, lr=learning_rate)
+
+    for i in range(100):
+        y = dh.run_kernel(kernel)
+        loss = (y - y_array).norm()
+
+        optimizer.zero_grad()
+
+        loss.backward()
+        assert y.requires_grad
+
+        optimizer.step()
+        print(loss.cpu().detach().numpy())
+        pyconrad.imshow(y)
+
+    pyconrad.imshow(dh.gpu_arrays)
+    pyconrad.imshow(dh.gpu_arrays)
-- 
GitLab