diff --git a/src/pystencils_autodiff/backends/_tensorflow.py b/src/pystencils_autodiff/backends/_tensorflow.py
index 94977b21e0b8839604168960d04ead9a1cd3642c..5286743989f1700c7c9879dde98dad2642b98d89 100644
--- a/src/pystencils_autodiff/backends/_tensorflow.py
+++ b/src/pystencils_autodiff/backends/_tensorflow.py
@@ -61,11 +61,13 @@ def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDi
 
     backward_func = getattr(compiled_op, stringcase.snakecase(
         stringcase.pascalcase("call_" + backward_ast.function_name)))
+    grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
 
-    def gradient_calculation(op, grad):
-        if isinstance(grad, Iterable):
+    def gradient_calculation(op, *grad):
+        if not isinstance(grad, Iterable):
             grad = [grad]
-        return backward_func(**{autodiff_obj.backward_input_fields[i].name: g for i, g in enumerate(grad)},
+
+        return backward_func(**{grad_fields[i].name: g for i, g in enumerate(grad)},
                              **{autodiff_obj.forward_input_fields[i].name: inp for i, inp in enumerate(op.inputs)
                                 if autodiff_obj.forward_input_fields[i] in backward_ast.fields_accessed})