Skip to content
Snippets Groups Projects
Commit 3238be05 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Do not fail when trying to register a gradient a second time

parent 60a0b38f
No related branches found
No related tags found
No related merge requests found
Pipeline #18167 failed
...@@ -69,9 +69,12 @@ def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDi ...@@ -69,9 +69,12 @@ def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDi
**{autodiff_obj.forward_input_fields[i].name: inp for i, inp in enumerate(op.inputs) **{autodiff_obj.forward_input_fields[i].name: inp for i, inp in enumerate(op.inputs)
if autodiff_obj.forward_input_fields[i] in backward_ast.fields_accessed}) if autodiff_obj.forward_input_fields[i] in backward_ast.fields_accessed})
tf.RegisterGradient(stringcase.pascalcase("call_" + forward_ast.function_name))( try:
gradient_calculation tf.RegisterGradient(stringcase.pascalcase("call_" + forward_ast.function_name))(
) gradient_calculation
)
except Exception:
pass
return getattr(compiled_op, stringcase.snakecase(stringcase.pascalcase("call_" + forward_ast.function_name))) return getattr(compiled_op, stringcase.snakecase(stringcase.pascalcase("call_" + forward_ast.function_name)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment