Skip to content
Snippets Groups Projects
Commit 3b21e27a authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fix gradient calculation for Tensorflow

parent 9ed3556c
Branches
Tags
No related merge requests found
......@@ -61,11 +61,13 @@ def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDi
backward_func = getattr(compiled_op, stringcase.snakecase(
stringcase.pascalcase("call_" + backward_ast.function_name)))
grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
def gradient_calculation(op, grad):
if isinstance(grad, Iterable):
def gradient_calculation(op, *grad):
if not isinstance(grad, Iterable):
grad = [grad]
return backward_func(**{autodiff_obj.backward_input_fields[i].name: g for i, g in enumerate(grad)},
return backward_func(**{grad_fields[i].name: g for i, g in enumerate(grad)},
**{autodiff_obj.forward_input_fields[i].name: inp for i, inp in enumerate(op.inputs)
if autodiff_obj.forward_input_fields[i] in backward_ast.fields_accessed})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment