diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index d1dda49502992c8f102e14b6d9c8b2571f8cd7b7..75f79bf7bb9d5541f4c7faa01ca41ec9f965a452 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -84,8 +84,10 @@ def create_autograd_function(autodiff_obj, use_cuda):
         assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
                    for i, f in enumerate(grad_fields))
-        assert all(a.is_cuda == use_cuda for a in grad_outputs), "Some of the tensors where on the wrong device. "
-        f"Op was compiled for CUDA: {str(use_cuda)}"
+        assert use_cuda in (True, False), "use_cuda needs to be True or False"
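+        # Parenthesize the message so the f-string is part of the assert;
+        # previously it sat on its own line as a stray no-op expression.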
+        assert all(a.is_cuda == use_cuda for a in grad_outputs), (
+            f"Some of the tensors were on the wrong device. Op was compiled for CUDA: {use_cuda}")
 
         for field in autodiff_obj.backward_output_fields:
             backward_output_tensors = OrderedDict({f.name: torch.zeros(field.shape,