diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py index 09675140a29f716a2c157594492a48ace807b550..b7b6248bcb9e99eca0185e613331b1463fb67e94 100644 --- a/tests/test_superresolution.py +++ b/tests/test_superresolution.py @@ -10,6 +10,7 @@ from os.path import dirname, join import numpy as np +import pytest import skimage.io import sympy @@ -309,7 +310,8 @@ def test_get_shift(): pyconrad.imshow(dh.gpu_arrays) -def test_get_shift_tensors(): +@pytest.mark.parametrize('scalar_experiment', (False,)) +def test_get_shift_tensors(scalar_experiment): from pystencils_autodiff.framework_integration.datahandling import PyTorchDataHandling import torch @@ -323,47 +325,102 @@ def test_get_shift_tensors(): dh.cpu_arrays['txw'][...] = 0.7 dh.cpu_arrays['tyw'][...] = -0.7 dh.all_to_gpu() + pyconrad.imshow(dh.gpu_arrays) kernel = pystencils_reco.AssignmentCollection({ - y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_), - interpolation_mode='cubic_spline') + y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + pystencils.y_)) }).create_pytorch_op()().call - dh.run_kernel(kernel) - y_array = dh.gpu_arrays['yw'] + y_array = dh.run_kernel(kernel) dh = PyTorchDataHandling(lenna.shape) x, y, tx, ty = dh.add_arrays('x, y, tx, ty') - dh.cpu_arrays['tx'] = torch.zeros(lenna.shape, requires_grad=True) - dh.cpu_arrays['ty'] = torch.zeros(lenna.shape, requires_grad=True) - dh.cpu_arrays['x'] = lenna - dh.all_to_gpu() - kernel = pystencils_reco.AssignmentCollection({ + if scalar_experiment: + var_x = torch.zeros((), requires_grad=True) + var_y = torch.zeros((), requires_grad=True) + else: + var_x = torch.zeros(lenna.shape, requires_grad=True) + var_y = torch.zeros(lenna.shape, requires_grad=True) + + dh.cpu_arrays.x = lenna + + assignments = pystencils_reco.AssignmentCollection({ y.center: x.interpolated_access((tx.center + pystencils.x_, 2 * ty.center + - pystencils.y_), interpolation_mode='cubic_spline') - }).create_pytorch_op(**dh.gpu_arrays) + pystencils.y_)) + }) + + print(pystencils.autodiff.create_backward_assignments(assignments)) + kernel = assignments.create_pytorch_op() print(kernel.ast) kernel = kernel().call - learning_rate = 1e-4 - params = (dh.cpu_arrays['tx'], dh.cpu_arrays['ty']) + learning_rate = 0.1 + params = (var_x, var_y) # assert all([p.is_leaf for p in params]) optimizer = torch.optim.Adam(params, lr=learning_rate) for i in range(100): + if scalar_experiment: + dh.cpu_arrays.tx = torch.ones(lenna.shape) * var_x + dh.cpu_arrays.ty = torch.ones(lenna.shape) * var_y + else: + dh.cpu_arrays.tx = var_x + dh.cpu_arrays.ty = var_y + dh.all_to_gpu() + y = dh.run_kernel(kernel) loss = (y - y_array).norm() optimizer.zero_grad() - loss.backward() + loss.backward(retain_graph=True) assert y.requires_grad optimizer.step() print(loss.cpu().detach().numpy()) - pyconrad.imshow(y) + print("var_x: " + str(var_x.mean())) + pyconrad.imshow(var_x) + # pyconrad.imshow(dh.gpu_arrays) + pyconrad.imshow(dh.gpu_arrays, wait_window_close=True) + + +@pytest.mark.parametrize('with_spline', ('with_spline', False)) +def test_spline_diff(with_spline): + from pystencils.fd import Diff + from pystencils.datahandling import SerialDataHandling + + lenna_file = join(dirname(__file__), "test_data", "lenna.png") + lenna = skimage.io.imread(lenna_file, as_gray=True).astype(np.float32) + + dh = SerialDataHandling(lenna.shape, default_target='gpu', default_ghost_layers=0, default_layout='numpy') + x, y, tx, ty = dh.add_arrays('x, y, tx, ty', dtype=np.float32) + + dh.cpu_arrays['x'] = lenna + dh.cpu_arrays['tx'][...] = 0.7 + dh.cpu_arrays['ty'][...] = -0.7 + out = dh.add_array('out', dtype=np.float32) + dh.all_to_gpu() + + kernel = pystencils_reco.AssignmentCollection({ + y.center: Diff(x, 0).interpolated_access((tx.center + pystencils.x_, + ty.center + pystencils.y_), + interpolation_mode='cubic_spline' if with_spline else 'linear') + }).compile(target='gpu') + + dh.run_kernel(kernel) + + print(pystencils.show_code(kernel)) + + kernel = pystencils_reco.AssignmentCollection({ + out.center: x.interpolated_access((tx.center + pystencils.x_, ty.center + pystencils.y_), + interpolation_mode='cubic_spline' if with_spline else 'linear') + }).compile(target='gpu') + + dh.run_kernel(kernel) + + print(pystencils.show_code(kernel)) pyconrad.imshow(dh.gpu_arrays) pyconrad.imshow(dh.gpu_arrays)