Skip to content
Snippets Groups Projects
Commit 585ac16b authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Call super resultion kernels

parent 9a53b949
No related merge requests found
Pipeline #22366 failed with stage
in 1 minute and 18 seconds
...@@ -67,7 +67,18 @@ def test_torch_simple(): ...@@ -67,7 +67,18 @@ def test_torch_simple():
x = torch.ones((10, 40)).cuda() x = torch.ones((10, 40)).cuda()
h = torch.ones((10, 40, 8)).cuda() h = torch.ones((10, 40, 8)).cuda()
kernel().forward(h, x) y = kernel().forward(h, x)
# with autograd
x = torch.ones((10, 40), requires_grad=True).cuda()
h = torch.ones((10, 40, 8), requires_grad=True).cuda()
y = kernel().forward(h, x)[0]
assert y.requires_grad
loss = y.mean()
loss.backward()
# kernel().forward(*([1]*9), x, y) # kernel().forward(*([1]*9), x, y)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment