diff --git a/src/pystencils_autodiff/backends/_tensorflow.py b/src/pystencils_autodiff/backends/_tensorflow.py index 1b7b3b7b6b4da18bd44b1d0c64e608f79c7772af..58eb2c8daf4c332b395f3f82c8cbaf6721da6cde 100644 --- a/src/pystencils_autodiff/backends/_tensorflow.py +++ b/src/pystencils_autodiff/backends/_tensorflow.py @@ -1,14 +1,14 @@ import tensorflow as tf + import pystencils_autodiff -import numpy as np -from pystencils.utils import DotDict +from tf.compat.v1 import get_default_graph _num_generated_ops = 0 def _py_func(func, inp, Tout, stateful=False, name=None, grad=None): """ - Copied from random internet forum. It seems to be important to give + Copied from random internet forum. It seems to be important to give PyFunc to give an random name in override map to properly register gradients PyFunc defined as given by Tensorflow @@ -29,14 +29,17 @@ def _py_func(func, inp, Tout, stateful=False, name=None, grad=None): tf.RegisterGradient(rnd_name)(grad) # Get current graph - g = tf.get_default_graph() + g = get_default_graph() # Add gradient override map with g.gradient_override_map({"PyFunc": rnd_name, "PyFuncStateless": rnd_name}): return tf.py_func(func, inp, Tout, stateful=stateful, name=name) -def tensorflowop_from_autodiffop(autodiffop: pystencils_autodiff.AutoDiffOp, inputfield_tensor_dict, forward_function, backward_function): +def tensorflowop_from_autodiffop(autodiffop: pystencils_autodiff.AutoDiffOp, + inputfield_tensor_dict, + forward_function, + backward_function): def helper_forward(*args): kwargs = dict() @@ -59,7 +62,12 @@ def tensorflowop_from_autodiffop(autodiffop: pystencils_autodiff.AutoDiffOp, inp return [rtn_dict[o.name] for o in autodiffop._backward_output_fields] def backward(op, *grad): - return tf.py_func(helper_backward, [*op.inputs, *grad], [f.dtype.numpy_dtype for f in autodiffop._backward_output_fields], name=autodiffop.op_name + '_backward', stateful=False) + return tf.py_func(helper_backward, + [*op.inputs, + *grad], + [f.dtype.numpy_dtype for f in autodiffop._backward_output_fields], + name=autodiffop.op_name + '_backward', + stateful=False) output_tensors = _py_func(helper_forward, [inputfield_tensor_dict[f]