Commit 8b80a78e authored by Stephan Seitz

Add additional assertion: use_cuda needs to be bool

parent c2cd47b7
Pipeline #20535 failed
@@ -84,8 +84,9 @@ def create_autograd_function(autodiff_obj, use_cuda):
         assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
                    for i, f in enumerate(grad_fields))
-        assert all(a.is_cuda == use_cuda for a in grad_outputs), "Some of the tensors where on the wrong device. "
-        f"Op was compiled for CUDA: {str(use_cuda)}"
+        assert use_cuda in (True, False), "use_cuda needs to be True or False"
+        assert all(a.is_cuda == use_cuda for a in grad_outputs), (
+            "Some of the tensors where on the wrong device. " f"Op was compiled for CUDA: {str(use_cuda)}")
         for field in autodiff_obj.backward_output_fields:
             backward_output_tensors = OrderedDict({f.name: torch.zeros(field.shape,
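The change does two things: it requires use_cuda to be a genuine bool before comparing it against torch.Tensor.is_cuda (which is itself a bool, so a non-bool value such as a device string makes the comparison fail for every tensor), and it parenthesizes the assertion message so the second f-string is actually part of the assert rather than a standalone expression. Below is a minimal sketch of the resulting check, using a hypothetical helper name _check_device and CPU-only tensors; it is not the actual pystencils_autodiff code.

    import torch

    def _check_device(grad_outputs, use_cuda):
        # Sketch mirroring the two assertions added by this commit.
        assert use_cuda in (True, False), "use_cuda needs to be True or False"
        assert all(a.is_cuda == use_cuda for a in grad_outputs), (
            "Some of the tensors were on the wrong device. "
            f"Op was compiled for CUDA: {use_cuda}")

    cpu_tensors = [torch.zeros(4), torch.ones(4)]
    _check_device(cpu_tensors, use_cuda=False)    # passes: CPU tensors, use_cuda=False
    # _check_device(cpu_tensors, use_cuda="cpu")  # before this commit, a non-bool value
    #                                             # triggered the misleading "wrong device"
    #                                             # message; now it fails fast with the
    #                                             # clearer "needs to be True or False" error.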