Skip to content
Snippets Groups Projects
Commit 25ebe582 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Allow use_cuda to be truthy (instead of only True/False)

parent f3772f74
Branches
Tags
No related merge requests found
Pipeline #20540 failed
...@@ -74,9 +74,12 @@ def create_autograd_function(autodiff_obj, use_cuda): ...@@ -74,9 +74,12 @@ def create_autograd_function(autodiff_obj, use_cuda):
return tuple(output_tensors.values()) return tuple(output_tensors.values())
def backward(self, *grad_outputs): def backward(self, *grad_outputs):
nonlocal use_cuda
if use_cuda: if use_cuda:
use_cuda = True
grad_outputs = [a.contiguous().cuda() for a in grad_outputs] grad_outputs = [a.contiguous().cuda() for a in grad_outputs]
else: else:
use_cuda = False
grad_outputs = [a.contiguous().cpu() for a in grad_outputs] grad_outputs = [a.contiguous().cpu() for a in grad_outputs]
grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields] grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
...@@ -84,7 +87,6 @@ def create_autograd_function(autodiff_obj, use_cuda): ...@@ -84,7 +87,6 @@ def create_autograd_function(autodiff_obj, use_cuda):
assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields)) assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim)) assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
for i, f in enumerate(grad_fields)) for i, f in enumerate(grad_fields))
assert use_cuda in (True, False), "use_cuda needs to be True or False"
assert all(a.is_cuda == use_cuda for a in grad_outputs), ( assert all(a.is_cuda == use_cuda for a in grad_outputs), (
"Some of the tensors where on the wrong device. " f"Op was compiled for CUDA: {str(use_cuda)}") "Some of the tensors where on the wrong device. " f"Op was compiled for CUDA: {str(use_cuda)}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment