diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py index 5f1f073506a4293102b850ca9b3387c44a45e07b..258f34dc1a01c5ea3180e550f0c4896f5e231eff 100644 --- a/tests/test_superresolution.py +++ b/tests/test_superresolution.py @@ -67,7 +67,18 @@ def test_torch_simple(): x = torch.ones((10, 40)).cuda() h = torch.ones((10, 40, 8)).cuda() - kernel().forward(h, x) + y = kernel().forward(h, x) + + # with autograd + x = torch.ones((10, 40), requires_grad=True).cuda() + h = torch.ones((10, 40, 8), requires_grad=True).cuda() + + y = kernel().forward(h, x)[0] + + assert y.requires_grad + + loss = y.mean() + loss.backward() # kernel().forward(*([1]*9), x, y)