From dff37a2adbd6f990d817ec2261efba36c54fa63e Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 13 Jan 2020 16:22:24 +0100 Subject: [PATCH] Avoid global import of tensorflow --- src/pystencils_autodiff/backends/_tensorflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pystencils_autodiff/backends/_tensorflow.py b/src/pystencils_autodiff/backends/_tensorflow.py index a7556c4..6e02720 100644 --- a/src/pystencils_autodiff/backends/_tensorflow.py +++ b/src/pystencils_autodiff/backends/_tensorflow.py @@ -1,7 +1,6 @@ from collections.abc import Iterable import stringcase -import tensorflow as tf from tensorflow.compat.v1 import get_default_graph, py_func import pystencils_autodiff @@ -25,6 +24,7 @@ def _py_func(func, inp, Tout, stateful=False, name=None, grad=None): :param grad: Custom Gradient Function :return: """ + import tensorflow as tf # Generate Random Gradient name in order to avoid conflicts with inbuilt names global _num_generated_ops rnd_name = 'PyFuncGrad' + str(_num_generated_ops) + 'ABC@a1b2c3' @@ -44,6 +44,7 @@ def _py_func(func, inp, Tout, stateful=False, name=None, grad=None): def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDiffOp, use_cuda): + import tensorflow as tf if use_cuda: forward_ast = autodiff_obj.forward_ast_gpu backward_ast = autodiff_obj.backward_ast_gpu -- GitLab