diff --git a/src/pystencils_reco/_assignment_collection.py b/src/pystencils_reco/_assignment_collection.py
index 2f8823083b31476f7e992a65652f14d637cfa909..82ff7c17a7ad0f6f222c62aa0356609bce46e91c 100644
--- a/src/pystencils_reco/_assignment_collection.py
+++ b/src/pystencils_reco/_assignment_collection.py
@@ -50,15 +50,17 @@ def get_type_of_arrays(*args):
     except Exception:
         pass
     try:
-        import torch
-        if any(isinstance(a, torch.Tensor) for a in args):
-            return NdArrayType.TORCH
+        if any('torch' in str(type(a)) for a in args):
+            import torch
+            if any(isinstance(a, torch.Tensor) for a in args):
+                return NdArrayType.TORCH
     except Exception as e:
         print(e)
     try:
-        from tensorflow import Tensor
-        if any(isinstance(a, Tensor) for a in args):
-            return NdArrayType.TENSORFLOW
+        if any('tensorflow' in str(type(a)) for a in args):
+            from tensorflow import Tensor
+            if any(isinstance(a, Tensor) for a in args):
+                return NdArrayType.TENSORFLOW
     except Exception:
         pass
     try: