Commit dff37a2a authored by Stephan Seitz

Avoid global import of tensorflow

parent 0cba0d6e
 from collections.abc import Iterable
 import stringcase
-import tensorflow as tf
 from tensorflow.compat.v1 import get_default_graph, py_func
 import pystencils_autodiff
@@ -25,6 +24,7 @@ def _py_func(func, inp, Tout, stateful=False, name=None, grad=None):
     :param grad: Custom Gradient Function
     :return:
     """
+    import tensorflow as tf
     # Generate Random Gradient name in order to avoid conflicts with inbuilt names
     global _num_generated_ops
     rnd_name = 'PyFuncGrad' + str(_num_generated_ops) + 'ABC@a1b2c3'
@@ -44,6 +44,7 @@ def _py_func(func, inp, Tout, stateful=False, name=None, grad=None):
 def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDiffOp,
                                         use_cuda):
+    import tensorflow as tf
     if use_cuda:
         forward_ast = autodiff_obj.forward_ast_gpu
         backward_ast = autodiff_obj.backward_ast_gpu
...
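
The change applies a deferred-import pattern: tensorflow is imported inside each function body instead of at module level, so merely importing the module no longer requires TensorFlow to be installed, and TensorFlow's import cost is paid only when one of these functions is actually called. A minimal sketch of the pattern, with a hypothetical function name not taken from this commit:

    def as_constant_tensor(array):
        # Deferred import: avoids a hard module-level dependency on tensorflow.
        # Python caches modules in sys.modules, so after the first call this
        # import amounts to a dictionary lookup.
        import tensorflow as tf
        return tf.constant(array)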