From 3b21e27ab6c2b40c6ec0e6467a4fa27e9cdbd7f8 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 29 Nov 2019 17:43:25 +0100
Subject: [PATCH] Fix gradient calculation for Tensorflow

---
 src/pystencils_autodiff/backends/_tensorflow.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/pystencils_autodiff/backends/_tensorflow.py b/src/pystencils_autodiff/backends/_tensorflow.py
index 94977b2..5286743 100644
--- a/src/pystencils_autodiff/backends/_tensorflow.py
+++ b/src/pystencils_autodiff/backends/_tensorflow.py
@@ -61,11 +61,13 @@ def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDi
 
     backward_func = getattr(compiled_op, stringcase.snakecase(
         stringcase.pascalcase("call_" + backward_ast.function_name)))
+    grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
 
-    def gradient_calculation(op, grad):
-        if isinstance(grad, Iterable):
+    def gradient_calculation(op, *grad):
+        if not isinstance(grad, Iterable):
             grad = [grad]
-        return backward_func(**{autodiff_obj.backward_input_fields[i].name: g for i, g in enumerate(grad)},
+
+        return backward_func(**{grad_fields[i].name: g for i, g in enumerate(grad)},
                              **{autodiff_obj.forward_input_fields[i].name: inp for i, inp in enumerate(op.inputs)
                                 if autodiff_obj.forward_input_fields[i] in backward_ast.fields_accessed})
 
-- 
GitLab