Commit cd642eb1 authored by Stephan Seitz

Remove deprecation warning (again :fries:)

parent ca171f5b
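The deprecation warning appears to come from the TF1-style graph lookup: `tf.get_default_graph()` has moved to the `tf.compat.v1` namespace in newer TensorFlow releases, so the diff below switches `_py_func` over to the compat endpoint. A minimal before/after sketch (not part of the commit; assumes a TensorFlow version that ships the `compat.v1` API):

import tensorflow as tf

# TF1-style call that can emit the deprecation warning (and is gone from the
# top-level namespace in TF 2.x):
# g = tf.get_default_graph()

# compat.v1 endpoint that remains available:
g = tf.compat.v1.get_default_graph()

Note that the import added below, `from tf.compat.v1 import get_default_graph`, would require a top-level module literally named `tf`; the installed package is `tensorflow`, so the usual spellings are `from tensorflow.compat.v1 import get_default_graph` or `tf.compat.v1.get_default_graph()` on the existing `import tensorflow as tf`.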
 import tensorflow as tf
 import pystencils_autodiff
-import numpy as np
+from tf.compat.v1 import get_default_graph
+from pystencils.utils import DotDict

 _num_generated_ops = 0

 def _py_func(func, inp, Tout, stateful=False, name=None, grad=None):
     """
     Copied from random internet forum. It seems to be important to give
     PyFunc to give an random name in override map to properly register gradients

     PyFunc defined as given by Tensorflow
@@ -29,14 +29,17 @@ def _py_func(func, inp, Tout, stateful=False, name=None, grad=None):
     tf.RegisterGradient(rnd_name)(grad)

     # Get current graph
-    g = tf.get_default_graph()
+    g = get_default_graph()

     # Add gradient override map
     with g.gradient_override_map({"PyFunc": rnd_name, "PyFuncStateless": rnd_name}):
         return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

-def tensorflowop_from_autodiffop(autodiffop: pystencils_autodiff.AutoDiffOp, inputfield_tensor_dict, forward_function, backward_function):
+def tensorflowop_from_autodiffop(autodiffop: pystencils_autodiff.AutoDiffOp,
+                                 inputfield_tensor_dict,
+                                 forward_function,
+                                 backward_function):

     def helper_forward(*args):
         kwargs = dict()
@@ -59,7 +62,12 @@ def tensorflowop_from_autodiffop(autodiffop: pystencils_autodiff.AutoDiffOp, inp
         return [rtn_dict[o.name] for o in autodiffop._backward_output_fields]

     def backward(op, *grad):
-        return tf.py_func(helper_backward, [*op.inputs, *grad], [f.dtype.numpy_dtype for f in autodiffop._backward_output_fields], name=autodiffop.op_name + '_backward', stateful=False)
+        return tf.py_func(helper_backward,
+                          [*op.inputs,
+                           *grad],
+                          [f.dtype.numpy_dtype for f in autodiffop._backward_output_fields],
+                          name=autodiffop.op_name + '_backward',
+                          stateful=False)

     output_tensors = _py_func(helper_forward,
                               [inputfield_tensor_dict[f]
...
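For readers unfamiliar with the pattern `_py_func` relies on, the standalone sketch below shows how a gradient registered under a unique name and applied through `gradient_override_map` attaches a custom backward pass to `py_func`; this is the mechanism `tensorflowop_from_autodiffop` uses to wire `backward_function` into the TensorFlow graph. The sketch is not part of the commit: the toy `square_forward`/`square_grad` functions and the session setup are illustrative assumptions, written against the `tf.compat.v1` graph-mode API.

# Standalone sketch of the gradient_override_map pattern used by _py_func.
# The square example is hypothetical and not part of pystencils_autodiff.
import uuid

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # py_func + gradient_override_map are graph-mode APIs


def square_forward(x):
    # Forward pass executed as an ordinary Python/NumPy function
    return np.square(x)


def square_grad(op, grad):
    # Symbolic gradient: d(x^2)/dx = 2*x, chained with the incoming gradient
    return 2.0 * op.inputs[0] * grad


# Register the gradient under a unique name, as _py_func does with its random
# name, so several generated ops do not overwrite each other's registration.
rnd_name = 'PyFuncGrad' + uuid.uuid4().hex
tf1.RegisterGradient(rnd_name)(square_grad)

g = tf1.get_default_graph()
with g.gradient_override_map({'PyFunc': rnd_name, 'PyFuncStateless': rnd_name}):
    x = tf1.placeholder(tf.float32, shape=(3,))
    y = tf1.py_func(square_forward, [x], tf.float32, stateful=False)

dy_dx, = tf1.gradients(y, x)

with tf1.Session() as sess:
    print(sess.run(dy_dx, feed_dict={x: np.array([1.0, 2.0, 3.0], dtype=np.float32)}))
    # expected: [2. 4. 6.]

As the docstring in the diff notes, the override name must be unique per generated op; reusing one name would make later registrations collide with earlier ones, which is why a random suffix is used.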