From 110391c04b3d56e761febd9e1924c7d6f1b12ea3 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Wed, 7 Aug 2019 16:23:18 +0200 Subject: [PATCH] Avoid deprecation warning by use of py_func --- src/pystencils_autodiff/backends/_tensorflow.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pystencils_autodiff/backends/_tensorflow.py b/src/pystencils_autodiff/backends/_tensorflow.py index e0e6f0b..b243437 100644 --- a/src/pystencils_autodiff/backends/_tensorflow.py +++ b/src/pystencils_autodiff/backends/_tensorflow.py @@ -62,12 +62,12 @@ def tensorflowop_from_autodiffop(autodiffop: pystencils_autodiff.AutoDiffOp, return [rtn_dict[o.name] for o in autodiffop._backward_output_fields] def backward(op, *grad): - return tf.py_func(helper_backward, - [*op.inputs, + return py_func(helper_backward, + [*op.inputs, *grad], - [f.dtype.numpy_dtype for f in autodiffop._backward_output_fields], - name=autodiffop.op_name + '_backward', - stateful=False) + [f.dtype.numpy_dtype for f in autodiffop._backward_output_fields], + name=autodiffop.op_name + '_backward', + stateful=False) output_tensors = _py_func(helper_forward, [inputfield_tensor_dict[f] -- GitLab