From 1e141dfb67e859bbac1144cec2563540fc1ff81d Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 29 Nov 2019 19:10:11 +0100
Subject: [PATCH] Adapt torch_native backend to the new interface

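With the new interface, grad_outputs apparently no longer carries a tensor for
fields that are also forward inputs, so the backward pass must filter those
fields out before zipping grad_outputs against the backward input fields.
A minimal sketch of the new mapping follows; the Field stand-in, the concrete
field lists, and the placeholder tensor are illustrative assumptions, not the
pystencils API:

    # Illustrative sketch only; Field is a stand-in for pystencils fields.
    from collections import namedtuple

    Field = namedtuple("Field", ["name"])
    x, dy = Field("x"), Field("dy")

    forward_input_fields = [x]       # inputs of the forward kernel
    backward_input_fields = [x, dy]  # backward kernel reads x and the output gradient dy
    grad_outputs = ["dy_tensor"]     # one entry per pure gradient field (stand-in for a torch.Tensor)

    # New mapping: drop fields that are already forward inputs before zipping.
    grad_fields = [f for f in backward_input_fields if f not in forward_input_fields]
    gradients = {f.name: grad_outputs[i] for i, f in enumerate(grad_fields)}
    assert gradients == {"dy": "dy_tensor"}
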
---
 src/pystencils_autodiff/backends/_torch_native.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 48bc4c2..d1dda49 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -78,10 +78,12 @@ def create_autograd_function(autodiff_obj, use_cuda):
             grad_outputs = [a.contiguous().cuda() for a in grad_outputs]
         else:
             grad_outputs = [a.contiguous().cpu() for a in grad_outputs]
-        gradients = {f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)}
-        assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(autodiff_obj.backward_input_fields))
+
+        grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
+        gradients = {f.name: grad_outputs[i] for i, f in enumerate(grad_fields)}
+        assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
-                   for i, f in enumerate(autodiff_obj.backward_input_fields))
+                   for i, f in enumerate(grad_fields))
         assert all(a.is_cuda == use_cuda for a in grad_outputs), ("Some of the tensors were on the wrong device. "
                                                                    f"Op was compiled for CUDA: {str(use_cuda)}")
 
-- 
GitLab