diff --git a/src/pystencils_autodiff/backends/_tensorflow.py b/src/pystencils_autodiff/backends/_tensorflow.py index 4ecae0f84efb01ef86022f2aab6bc8ed986754f5..e0e6f0b98ece00ed00b29a2ecdf5859bb5094b4d 100644 --- a/src/pystencils_autodiff/backends/_tensorflow.py +++ b/src/pystencils_autodiff/backends/_tensorflow.py @@ -1,7 +1,7 @@ import tensorflow as tf +from tensorflow.compat.v1 import get_default_graph, py_func import pystencils_autodiff -from tensorflow.compat.v1 import get_default_graph _num_generated_ops = 0 @@ -33,7 +33,7 @@ def _py_func(func, inp, Tout, stateful=False, name=None, grad=None): # Add gradient override map with g.gradient_override_map({"PyFunc": rnd_name, "PyFuncStateless": rnd_name}): - return tf.py_func(func, inp, Tout, stateful=stateful, name=name) + return py_func(func, inp, Tout, stateful=stateful, name=name) def tensorflowop_from_autodiffop(autodiffop: pystencils_autodiff.AutoDiffOp,