From 3b21e27ab6c2b40c6ec0e6467a4fa27e9cdbd7f8 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 29 Nov 2019 17:43:25 +0100 Subject: [PATCH] Fix gradient calculation for Tensorflow --- src/pystencils_autodiff/backends/_tensorflow.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/pystencils_autodiff/backends/_tensorflow.py b/src/pystencils_autodiff/backends/_tensorflow.py index 94977b2..5286743 100644 --- a/src/pystencils_autodiff/backends/_tensorflow.py +++ b/src/pystencils_autodiff/backends/_tensorflow.py @@ -61,11 +61,13 @@ def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDi backward_func = getattr(compiled_op, stringcase.snakecase( stringcase.pascalcase("call_" + backward_ast.function_name))) + grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields] - def gradient_calculation(op, grad): - if isinstance(grad, Iterable): + def gradient_calculation(op, *grad): + if not isinstance(grad, Iterable): grad = [grad] - return backward_func(**{autodiff_obj.backward_input_fields[i].name: g for i, g in enumerate(grad)}, + + return backward_func(**{grad_fields[i].name: g for i, g in enumerate(grad)}, **{autodiff_obj.forward_input_fields[i].name: inp for i, inp in enumerate(op.inputs) if autodiff_obj.forward_input_fields[i] in backward_ast.fields_accessed}) -- GitLab