diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py
index 5f1f073506a4293102b850ca9b3387c44a45e07b..258f34dc1a01c5ea3180e550f0c4896f5e231eff 100644
--- a/tests/test_superresolution.py
+++ b/tests/test_superresolution.py
@@ -67,7 +67,18 @@ def test_torch_simple():
     x = torch.ones((10, 40)).cuda()
     h = torch.ones((10, 40, 8)).cuda()
 
-    kernel().forward(h, x)
+    y = kernel().forward(h, x)
+
+    # with autograd
+    x = torch.ones((10, 40), requires_grad=True).cuda()
+    h = torch.ones((10, 40, 8), requires_grad=True).cuda()
+
+    y = kernel().forward(h, x)[0]
+
+    assert y.requires_grad
+
+    loss = y.mean()
+    loss.backward()
     # kernel().forward(*([1]*9), x, y)