diff --git a/src/pystencils_reco/_assignment_collection.py b/src/pystencils_reco/_assignment_collection.py index 2f8823083b31476f7e992a65652f14d637cfa909..82ff7c17a7ad0f6f222c62aa0356609bce46e91c 100644 --- a/src/pystencils_reco/_assignment_collection.py +++ b/src/pystencils_reco/_assignment_collection.py @@ -50,15 +50,17 @@ def get_type_of_arrays(*args): except Exception: pass try: - import torch - if any(isinstance(a, torch.Tensor) for a in args): - return NdArrayType.TORCH + if any('torch' in str(type(a)) for a in args): + import torch + if any(isinstance(a, torch.Tensor) for a in args): + return NdArrayType.TORCH except Exception as e: print(e) try: - from tensorflow import Tensor - if any(isinstance(a, Tensor) for a in args): - return NdArrayType.TENSORFLOW + if any('tensorflow' in str(type(a)) for a in args): + from tensorflow import Tensor + if any(isinstance(a, Tensor) for a in args): + return NdArrayType.TENSORFLOW except Exception: pass try: