From bc8c41d3dbbd938f39080a85934367f5dc200ce9 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 13 Jan 2020 14:06:36 +0100 Subject: [PATCH] Avoid importing of tensorflow/torch for checking arrays --- src/pystencils_reco/_assignment_collection.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/pystencils_reco/_assignment_collection.py b/src/pystencils_reco/_assignment_collection.py index 2f88230..82ff7c1 100644 --- a/src/pystencils_reco/_assignment_collection.py +++ b/src/pystencils_reco/_assignment_collection.py @@ -50,15 +50,17 @@ def get_type_of_arrays(*args): except Exception: pass try: - import torch - if any(isinstance(a, torch.Tensor) for a in args): - return NdArrayType.TORCH + if any('torch' in str(type(a)) for a in args): + import torch + if any(isinstance(a, torch.Tensor) for a in args): + return NdArrayType.TORCH except Exception as e: print(e) try: - from tensorflow import Tensor - if any(isinstance(a, Tensor) for a in args): - return NdArrayType.TENSORFLOW + if any('tensorflow' in str(type(a)) for a in args): + from tensorflow import Tensor + if any(isinstance(a, Tensor) for a in args): + return NdArrayType.TENSORFLOW except Exception: pass try: -- GitLab