From af7bfff213b1ed152186440af5ebc7950e7d7035 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 14 Jan 2020 18:41:54 +0100
Subject: [PATCH] Refactor interpolation

---
 src/pystencils_reco/_assignment_collection.py |  3 +
 tests/test_superresolution.py                 | 83 ++++++++++++++++++-
 2 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_reco/_assignment_collection.py b/src/pystencils_reco/_assignment_collection.py
index 82ff7c1..e236e03 100644
--- a/src/pystencils_reco/_assignment_collection.py
+++ b/src/pystencils_reco/_assignment_collection.py
@@ -181,6 +181,9 @@ class AssignmentCollection(pystencils.AssignmentCollection):
                                 if hasattr(t, 'requires_grad') and not t.requires_grad]
         constant_fields = {f for f in self.free_fields if f.name in constant_field_names}
 
+        for n in [f for f, t in kwargs.items() if hasattr(t, 'requires_grad')]:
+            kwargs.pop(n)
+
         if not self._autodiff:
             if hasattr(self, '_create_autodiff'):
                 self._create_autodiff(constant_fields, **kwargs)
diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py
index 473c58c..0967514 100644
--- a/tests/test_superresolution.py
+++ b/tests/test_superresolution.py
@@ -259,11 +259,29 @@ def test_spatial_derivative():
     tx, ty = pystencils.fields('t_x, t_y: float32[2d]')
 
     assignments = pystencils.AssignmentCollection({
-        y.center: x.interpolated_access((tx.center, 2 * ty.center))
+        y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_))
     })
 
     backward_assignments = pystencils.autodiff.create_backward_assignments(assignments)
 
+    print("assignments: " + str(assignments))
+    print("backward_assignments: " + str(backward_assignments))
+
+
+def test_spatial_derivative2():
+    import pystencils.interpolation_astnodes
+
+    x, y = pystencils.fields('x, y: float32[2d]')
+    tx, ty = pystencils.fields('t_x, t_y: float32[2d]')
+
+    assignments = pystencils.AssignmentCollection({
+        y.center: x.interpolated_access((tx.center + pystencils.x_, ty.center + 2 * pystencils.y_))
+    })
+
+    backward_assignments = pystencils.autodiff.create_backward_assignments(assignments)
+
+    assert backward_assignments.atoms(pystencils.interpolation_astnodes.DiffInterpolatorAccess)
+
     print("assignments: " + str(assignments))
     print("backward_assignments: " + str(backward_assignments))
 
@@ -282,9 +300,70 @@ def test_get_shift():
     dh.all_to_gpu()
 
     kernel = pystencils_reco.AssignmentCollection({
-        y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_))
+        y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center +
+                                         pystencils.y_), interpolation_mode='cubic_spline')
     }).create_pytorch_op()().forward
 
     dh.run_kernel(kernel)
 
     pyconrad.imshow(dh.gpu_arrays)
+
+
+def test_get_shift_tensors():
+    from pystencils_autodiff.framework_integration.datahandling import PyTorchDataHandling
+    import torch
+
+    lenna_file = join(dirname(__file__), "test_data", "lenna.png")
+    lenna = skimage.io.imread(lenna_file, as_gray=True).astype(np.float32)
+
+    dh = PyTorchDataHandling(lenna.shape)
+    x, y, tx, ty = dh.add_arrays('xw, yw, txw, tyw')
+
+    dh.cpu_arrays['xw'] = lenna
+    dh.cpu_arrays['txw'][...] = 0.7
+    dh.cpu_arrays['tyw'][...] = -0.7
+    dh.all_to_gpu()
+
+    kernel = pystencils_reco.AssignmentCollection({
+        y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_),
+                                        interpolation_mode='cubic_spline')
+    }).create_pytorch_op()().call
+
+    dh.run_kernel(kernel)
+    y_array = dh.gpu_arrays['yw']
+
+    dh = PyTorchDataHandling(lenna.shape)
+    x, y, tx, ty = dh.add_arrays('x, y, tx, ty')
+    dh.cpu_arrays['tx'] = torch.zeros(lenna.shape, requires_grad=True)
+    dh.cpu_arrays['ty'] = torch.zeros(lenna.shape, requires_grad=True)
+    dh.cpu_arrays['x'] = lenna
+    dh.all_to_gpu()
+
+    kernel = pystencils_reco.AssignmentCollection({
+        y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center +
+                                         pystencils.y_), interpolation_mode='cubic_spline')
+    }).create_pytorch_op(**dh.gpu_arrays)
+
+    print(kernel.ast)
+    kernel = kernel().call
+
+    learning_rate = 1e-4
+    params = (dh.cpu_arrays['tx'], dh.cpu_arrays['ty'])
+    # assert all([p.is_leaf for p in params])
+    optimizer = torch.optim.Adam(params, lr=learning_rate)
+
+    for i in range(100):
+        y = dh.run_kernel(kernel)
+        loss = (y - y_array).norm()
+
+        optimizer.zero_grad()
+
+        loss.backward()
+        assert y.requires_grad
+
+        optimizer.step()
+        print(loss.cpu().detach().numpy())
+        pyconrad.imshow(y)
+
+    pyconrad.imshow(dh.gpu_arrays)
+    pyconrad.imshow(dh.gpu_arrays)
-- 
GitLab
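
Note on the _assignment_collection.py hunk above: it pops every kwarg that exposes a
requires_grad attribute (i.e. every torch tensor, whether or not it actually requires
grad) before the remaining options are forwarded to _create_autodiff, and it collects
the names into a list first so the dict is not mutated while it is being iterated.
A minimal standalone sketch of that filtering idiom, assuming plain torch tensors;
the function name drop_tensor_kwargs is made up for illustration:

    import torch

    def drop_tensor_kwargs(**kwargs):
        # Collect names first, then pop, so kwargs is not mutated mid-iteration.
        # hasattr(value, 'requires_grad') is true for any torch tensor.
        tensor_names = [name for name, value in kwargs.items()
                        if hasattr(value, 'requires_grad')]
        for name in tensor_names:
            kwargs.pop(name)
        return kwargs  # only non-tensor options survive

    remaining = drop_tensor_kwargs(tx=torch.zeros(2, requires_grad=True),
                                   interpolation_mode='cubic_spline')
    assert remaining == {'interpolation_mode': 'cubic_spline'}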