Commit 1e141dfb authored by Stephan Seitz

Change torch native for new interface

parent 75dd38f2
Pipeline #20141 failed in 1 minute and 39 seconds
@@ -78,10 +78,12 @@ def create_autograd_function(autodiff_obj, use_cuda):
             grad_outputs = [a.contiguous().cuda() for a in grad_outputs]
         else:
             grad_outputs = [a.contiguous().cpu() for a in grad_outputs]
-        gradients = {f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)}
-        assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(autodiff_obj.backward_input_fields))
+        grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
+        gradients = {f.name: grad_outputs[i] for i, f in enumerate(grad_fields)}
+        assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
-                   for i, f in enumerate(autodiff_obj.backward_input_fields))
+                   for i, f in enumerate(grad_fields))
         assert all(a.is_cuda == use_cuda for a in grad_outputs), "Some of the tensors where on the wrong device. "
         f"Op was compiled for CUDA: {str(use_cuda)}"
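For context, a minimal, self-contained sketch of the pattern the new grad_fields filter implements. All class and field names below (Field, Autodiff, "x", "diff_y") are hypothetical stand-ins, not names from pystencils_autodiff: PyTorch hands backward() exactly one gradient tensor per forward output, so only the backward input fields that are not also forward input fields should be matched positionally against grad_outputs.

import torch

class Field:
    # Hypothetical stand-in for a pystencils field; only the name matters here.
    def __init__(self, name):
        self.name = name

class Autodiff:
    # Hypothetical stand-in for autodiff_obj: the backward kernel reads the
    # forward input "x" as well as the incoming gradient "diff_y".
    def __init__(self):
        x = Field("x")
        self.forward_input_fields = [x]
        self.backward_input_fields = [x, Field("diff_y")]

autodiff_obj = Autodiff()
grad_outputs = (torch.ones(3),)  # one tensor per forward output, as passed to backward()

# Same filtering as in the commit: fields that are also forward inputs must not
# consume entries of grad_outputs.
grad_fields = [f for f in autodiff_obj.backward_input_fields
               if f not in autodiff_obj.forward_input_fields]
gradients = {f.name: grad_outputs[i] for i, f in enumerate(grad_fields)}

assert list(gradients) == ["diff_y"]

Without the filter, "x" would consume the gradient tensor meant for "diff_y", and the shape and stride assertions in the diff would be checked against the wrong field.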