diff --git a/src/pystencils_autodiff/backends/_tensorflow.py b/src/pystencils_autodiff/backends/_tensorflow.py index cd82bfd0a8eed2d1ee2d19c6b576f995ff4ee8c4..db85df108fcc0907aea70e59e134046cc023dad3 100644 --- a/src/pystencils_autodiff/backends/_tensorflow.py +++ b/src/pystencils_autodiff/backends/_tensorflow.py @@ -69,9 +69,12 @@ def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDi **{autodiff_obj.forward_input_fields[i].name: inp for i, inp in enumerate(op.inputs) if autodiff_obj.forward_input_fields[i] in backward_ast.fields_accessed}) - tf.RegisterGradient(stringcase.pascalcase("call_" + forward_ast.function_name))( - gradient_calculation - ) + try: + tf.RegisterGradient(stringcase.pascalcase("call_" + forward_ast.function_name))( + gradient_calculation + ) + except Exception: + pass return getattr(compiled_op, stringcase.snakecase(stringcase.pascalcase("call_" + forward_ast.function_name)))