Skip to content
Snippets Groups Projects
Commit 1e141dfb authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Change torch native for new interface

parent 75dd38f2
No related branches found
No related tags found
No related merge requests found
Pipeline #20141 failed
...@@ -78,10 +78,12 @@ def create_autograd_function(autodiff_obj, use_cuda): ...@@ -78,10 +78,12 @@ def create_autograd_function(autodiff_obj, use_cuda):
grad_outputs = [a.contiguous().cuda() for a in grad_outputs] grad_outputs = [a.contiguous().cuda() for a in grad_outputs]
else: else:
grad_outputs = [a.contiguous().cpu() for a in grad_outputs] grad_outputs = [a.contiguous().cpu() for a in grad_outputs]
gradients = {f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)}
assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(autodiff_obj.backward_input_fields)) grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
gradients = {f.name: grad_outputs[i] for i, f in enumerate(grad_fields)}
assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim)) assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
for i, f in enumerate(autodiff_obj.backward_input_fields)) for i, f in enumerate(grad_fields))
assert all(a.is_cuda == use_cuda for a in grad_outputs), "Some of the tensors where on the wrong device. " assert all(a.is_cuda == use_cuda for a in grad_outputs), "Some of the tensors where on the wrong device. "
f"Op was compiled for CUDA: {str(use_cuda)}" f"Op was compiled for CUDA: {str(use_cuda)}"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment