diff --git a/src/pystencils_autodiff/backends/_tensorflow.py b/src/pystencils_autodiff/backends/_tensorflow.py
index a7556c4f9d41e028fe73205918eb13d56cf546b1..6e027200194c511644136f0e9a560356e4c4682f 100644
--- a/src/pystencils_autodiff/backends/_tensorflow.py
+++ b/src/pystencils_autodiff/backends/_tensorflow.py
@@ -1,7 +1,6 @@
 from collections.abc import Iterable
 
 import stringcase
-import tensorflow as tf
 from tensorflow.compat.v1 import get_default_graph, py_func
 
 import pystencils_autodiff
@@ -25,6 +24,7 @@ def _py_func(func, inp, Tout, stateful=False, name=None, grad=None):
     :param grad: Custom Gradient Function
     :return:
     """
+    import tensorflow as tf
     # Generate Random Gradient name in order to avoid conflicts with inbuilt names
     global _num_generated_ops
     rnd_name = 'PyFuncGrad' + str(_num_generated_ops) + 'ABC@a1b2c3'
@@ -44,6 +44,7 @@ def _py_func(func, inp, Tout, stateful=False, name=None, grad=None):
 
 
 def native_tensorflowop_from_autodiffop(autodiff_obj: pystencils_autodiff.AutoDiffOp, use_cuda):
+    import tensorflow as tf
     if use_cuda:
         forward_ast = autodiff_obj.forward_ast_gpu
         backward_ast = autodiff_obj.backward_ast_gpu
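
The patch above defers the TensorFlow import: the module-level import tensorflow as tf is removed and the import is repeated inside _py_func and native_tensorflowop_from_autodiffop, so TensorFlow is only loaded once one of those functions actually runs. A minimal sketch of that deferred-import pattern, using a hypothetical make_constant helper that is not part of this patch:

    def make_constant(value):
        import tensorflow as tf  # deferred import: TensorFlow loads only when this function runs
        return tf.constant(value)  # any use of the lazily imported module

    # Usage: importing the enclosing module stays cheap; TensorFlow loads here.
    print(make_constant(0.0))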