diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py index 16c7d2f71ec6623cf29a525fe42f6dca61ebabb9..a879abb91111da927632b64840c9f8b854075774 100644 --- a/tests/test_superresolution.py +++ b/tests/test_superresolution.py @@ -426,3 +426,65 @@ def test_spline_diff(with_spline): pyconrad.imshow(dh.gpu_arrays) pyconrad.imshow(dh.gpu_arrays) + + +@pytest.mark.parametrize('scalar_experiment', (False,)) +def test_rotation(scalar_experiment): + from pystencils_autodiff.framework_integration.datahandling import PyTorchDataHandling + from pystencils_reco.resampling import rotation_transform + + import torch + + lenna_file = join(dirname(__file__), "test_data", "lenna.png") + lenna = skimage.io.imread(lenna_file, as_gray=True).astype(np.float32) + + GROUNDTRUTH_ANGLE = 0.3 + + target = np.zeros(lenna.shape) + rotation_transform(lenna, target, GROUNDTRUTH_ANGLE)() + target = torch.Tensor(target).cuda() + + dh = PyTorchDataHandling(lenna.shape) + x, y, angle = dh.add_arrays('x, y, angle') + + if scalar_experiment: + var_angle = torch.zeros((), requires_grad=True) + else: + var_angle = torch.zeros(lenna.shape, requires_grad=True) + + var_lenna = torch.autograd.Variable(torch.from_numpy( + lenna + np.random.randn(*lenna.shape).astype(np.float32)), requires_grad=True) + assert var_lenna.requires_grad + + learning_rate = 0.1 + params = (var_angle, var_lenna) + + optimizer = torch.optim.Adam(params, lr=learning_rate) + + assignments = rotation_transform(x, y, angle) + kernel = assignments.create_pytorch_op() + print(kernel) + + kernel = kernel().call + + for i in range(100000): + if scalar_experiment: + dh.cpu_arrays.angle = torch.ones(lenna.shape) * (var_angle + 0.29) + else: + dh.cpu_arrays.angle = var_angle + dh.cpu_arrays.x = var_lenna + dh.all_to_gpu() + + y = dh.run_kernel(kernel) + loss = (y - target).norm() + + optimizer.zero_grad() + + loss.backward(retain_graph=True) + assert y.requires_grad + + optimizer.step() + print(loss.cpu().detach().numpy()) + pyconrad.imshow(var_lenna) + + pyconrad.show_everything()