Skip to content
Snippets Groups Projects
Commit cd642eb1 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Remove deprecation warning (again :fries:)

parent ca171f5b
No related branches found
No related tags found
No related merge requests found
import tensorflow as tf
import pystencils_autodiff
import numpy as np
from pystencils.utils import DotDict
from tf.compat.v1 import get_default_graph
_num_generated_ops = 0
def _py_func(func, inp, Tout, stateful=False, name=None, grad=None):
"""
Copied from random internet forum. It seems to be important to give
Copied from random internet forum. It seems to be important to give
PyFunc to give an random name in override map to properly register gradients
PyFunc defined as given by Tensorflow
......@@ -29,14 +29,17 @@ def _py_func(func, inp, Tout, stateful=False, name=None, grad=None):
tf.RegisterGradient(rnd_name)(grad)
# Get current graph
g = tf.get_default_graph()
g = get_default_graph()
# Add gradient override map
with g.gradient_override_map({"PyFunc": rnd_name, "PyFuncStateless": rnd_name}):
return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
def tensorflowop_from_autodiffop(autodiffop: pystencils_autodiff.AutoDiffOp, inputfield_tensor_dict, forward_function, backward_function):
def tensorflowop_from_autodiffop(autodiffop: pystencils_autodiff.AutoDiffOp,
inputfield_tensor_dict,
forward_function,
backward_function):
def helper_forward(*args):
kwargs = dict()
......@@ -59,7 +62,12 @@ def tensorflowop_from_autodiffop(autodiffop: pystencils_autodiff.AutoDiffOp, inp
return [rtn_dict[o.name] for o in autodiffop._backward_output_fields]
def backward(op, *grad):
return tf.py_func(helper_backward, [*op.inputs, *grad], [f.dtype.numpy_dtype for f in autodiffop._backward_output_fields], name=autodiffop.op_name + '_backward', stateful=False)
return tf.py_func(helper_backward,
[*op.inputs,
*grad],
[f.dtype.numpy_dtype for f in autodiffop._backward_output_fields],
name=autodiffop.op_name + '_backward',
stateful=False)
output_tensors = _py_func(helper_forward,
[inputfield_tensor_dict[f]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment