diff --git a/src/pystencils_autodiff/backends/_tensorflow.py b/src/pystencils_autodiff/backends/_tensorflow.py index 94977b21e0b8839604168960d04ead9a1cd3642c..5286743989f1700c7c9879dde98dad2642b98d89 100644 --- a/src/pystencils_autodiff/backends/_tensorflow.py +++ b/src/pystencils_autodiff/backends/_tensorflow.py @@ -61,11 +61,13 @@ def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDi backward_func = getattr(compiled_op, stringcase.snakecase( stringcase.pascalcase("call_" + backward_ast.function_name))) + grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields] - def gradient_calculation(op, grad): - if isinstance(grad, Iterable): + def gradient_calculation(op, *grad): + if not isinstance(grad, Iterable): grad = [grad] - return backward_func(**{autodiff_obj.backward_input_fields[i].name: g for i, g in enumerate(grad)}, + + return backward_func(**{grad_fields[i].name: g for i, g in enumerate(grad)}, **{autodiff_obj.forward_input_fields[i].name: inp for i, inp in enumerate(op.inputs) if autodiff_obj.forward_input_fields[i] in backward_ast.fields_accessed})