From 8b80a78e6840b78fada0bb192e3615a907f50540 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 16 Dec 2019 12:05:48 +0100
Subject: [PATCH] Add additional assertion: use_cuda needs to be bool

---
 src/pystencils_autodiff/backends/_torch_native.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index d1dda49..75f79bf 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -84,8 +84,9 @@ def create_autograd_function(autodiff_obj, use_cuda):
         assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
                    for i, f in enumerate(grad_fields))
-        assert all(a.is_cuda == use_cuda for a in grad_outputs), "Some of the tensors where on the wrong device. "
-        f"Op was compiled for CUDA: {str(use_cuda)}"
+        assert use_cuda in (True, False), "use_cuda needs to be True or False"
+        assert all(a.is_cuda == use_cuda for a in grad_outputs), (
+            "Some of the tensors were on the wrong device. " f"Op was compiled for CUDA: {str(use_cuda)}")

         for field in autodiff_obj.backward_output_fields:
             backward_output_tensors = OrderedDict({f.name: torch.zeros(field.shape,
--
GitLab
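
A minimal sketch (not part of the patch) of the failure mode the new assertion
guards against: if use_cuda arrives as a truthy non-bool such as the string
"cuda", the comparison a.is_cuda == use_cuda is False even for correctly placed
tensors, so the old code raised a misleading device error. The helper name
check_use_cuda and the SimpleNamespace stand-in for a torch tensor are
hypothetical, chosen so the snippet runs without torch installed:

    from types import SimpleNamespace

    def check_use_cuda(use_cuda, tensors):
        # Fail fast on non-bool flags before comparing against tensor.is_cuda,
        # mirroring the assertion added in this patch.
        assert use_cuda in (True, False), "use_cuda needs to be True or False"
        assert all(t.is_cuda == use_cuda for t in tensors), (
            "Some of the tensors were on the wrong device. "
            f"Op was compiled for CUDA: {use_cuda}")

    cpu_tensor = SimpleNamespace(is_cuda=False)  # stand-in for a CPU tensor
    check_use_cuda(False, [cpu_tensor])          # passes
    # check_use_cuda("cuda", [cpu_tensor])       # raises: use_cuda needs to be True or False

Note that the membership test use_cuda in (True, False) still accepts the
integers 0 and 1, since they compare equal to the bools; isinstance(use_cuda, bool)
would be the stricter check.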