diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 75f79bf7bb9d5541f4c7faa01ca41ec9f965a452..9fae28c5f8ebf28bca5f5f95e4c10d04b20edd9b 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -74,9 +74,12 @@ def create_autograd_function(autodiff_obj, use_cuda):
         return tuple(output_tensors.values())
 
     def backward(self, *grad_outputs):
+        nonlocal use_cuda
         if use_cuda:
+            use_cuda = True
             grad_outputs = [a.contiguous().cuda() for a in grad_outputs]
         else:
+            use_cuda = False
             grad_outputs = [a.contiguous().cpu() for a in grad_outputs]
 
         grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
@@ -84,7 +87,6 @@ def create_autograd_function(autodiff_obj, use_cuda):
         assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
                    for i, f in enumerate(grad_fields))
-        assert use_cuda in (True, False), "use_cuda needs to be True or False"
         assert all(a.is_cuda == use_cuda for a in grad_outputs), (
             "Some of the tensors where on the wrong device. " f"Op was compiled for CUDA: {str(use_cuda)}")