From bc8c41d3dbbd938f39080a85934367f5dc200ce9 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 13 Jan 2020 14:06:36 +0100
Subject: [PATCH] Avoid importing of tensorflow/torch for checking arrays

---
 src/pystencils_reco/_assignment_collection.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/pystencils_reco/_assignment_collection.py b/src/pystencils_reco/_assignment_collection.py
index 2f88230..82ff7c1 100644
--- a/src/pystencils_reco/_assignment_collection.py
+++ b/src/pystencils_reco/_assignment_collection.py
@@ -50,15 +50,17 @@ def get_type_of_arrays(*args):
     except Exception:
         pass
     try:
-        import torch
-        if any(isinstance(a, torch.Tensor) for a in args):
-            return NdArrayType.TORCH
+        if any('torch' in str(type(a)) for a in args):
+            import torch
+            if any(isinstance(a, torch.Tensor) for a in args):
+                return NdArrayType.TORCH
     except Exception as e:
         print(e)
     try:
-        from tensorflow import Tensor
-        if any(isinstance(a, Tensor) for a in args):
-            return NdArrayType.TENSORFLOW
+        if any('tensorflow' in str(type(a)) for a in args):
+            from tensorflow import Tensor
+            if any(isinstance(a, Tensor) for a in args):
+                return NdArrayType.TENSORFLOW
     except Exception:
         pass
     try:
-- 
GitLab