diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 48bc4c2047e6fe745fa21843e165758adafa2fa7..d1dda49502992c8f102e14b6d9c8b2571f8cd7b7 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -78,10 +78,12 @@ def create_autograd_function(autodiff_obj, use_cuda):
             grad_outputs = [a.contiguous().cuda() for a in grad_outputs]
         else:
             grad_outputs = [a.contiguous().cpu() for a in grad_outputs]
-        gradients = {f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)}
-        assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(autodiff_obj.backward_input_fields))
+
+        grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
+        gradients = {f.name: grad_outputs[i] for i, f in enumerate(grad_fields)}
+        assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
-                   for i, f in enumerate(autodiff_obj.backward_input_fields))
+                   for i, f in enumerate(grad_fields))
         assert all(a.is_cuda == use_cuda for a in grad_outputs), "Some of the tensors where on the wrong device. " f"Op was compiled for CUDA: {str(use_cuda)}"
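
The hunk narrows the list of fields that grad_outputs is matched against: backward_input_fields may also contain forward inputs, which the backward kernel reads but which do not correspond to any incoming gradient tensor, so indexing grad_outputs by the full list can misalign names, shapes and strides. Below is a minimal, self-contained sketch of that filtering; the Field dataclass and the field names are hypothetical stand-ins for the real pystencils field objects, only the filtering expression mirrors the patch.

    # Sketch only, not part of the patch. Field is a stand-in for pystencils fields.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Field:
        name: str
        shape: tuple

    # Forward inputs are also listed among the backward kernel's inputs,
    # but torch.autograd only hands over gradients for the true outputs.
    forward_input_fields = [Field("x", (16, 16))]
    backward_input_fields = [Field("x", (16, 16)), Field("diff_y", (16, 16))]

    # Same filtering as in the patch: drop forward inputs so that the
    # remaining fields line up one-to-one with grad_outputs.
    grad_fields = [f for f in backward_input_fields if f not in forward_input_fields]
    assert [f.name for f in grad_fields] == ["diff_y"]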