From 585ac16b74752049f3added51e236f62c66dd51a Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 28 Feb 2020 17:51:27 +0100 Subject: [PATCH] Call super resultion kernels --- tests/test_superresolution.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py index 5f1f073..258f34d 100644 --- a/tests/test_superresolution.py +++ b/tests/test_superresolution.py @@ -67,7 +67,18 @@ def test_torch_simple(): x = torch.ones((10, 40)).cuda() h = torch.ones((10, 40, 8)).cuda() - kernel().forward(h, x) + y = kernel().forward(h, x) + + # with autograd + x = torch.ones((10, 40), requires_grad=True).cuda() + h = torch.ones((10, 40, 8), requires_grad=True).cuda() + + y = kernel().forward(h, x)[0] + + assert y.requires_grad + + loss = y.mean() + loss.backward() # kernel().forward(*([1]*9), x, y) -- GitLab