Skip to content
Snippets Groups Projects
Commit bc8c41d3 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Avoid importing of tensorflow/torch for checking arrays

parent 053a14ff
No related branches found
No related tags found
No related merge requests found
...@@ -50,15 +50,17 @@ def get_type_of_arrays(*args): ...@@ -50,15 +50,17 @@ def get_type_of_arrays(*args):
except Exception: except Exception:
pass pass
try: try:
import torch if any('torch' in str(type(a)) for a in args):
if any(isinstance(a, torch.Tensor) for a in args): import torch
return NdArrayType.TORCH if any(isinstance(a, torch.Tensor) for a in args):
return NdArrayType.TORCH
except Exception as e: except Exception as e:
print(e) print(e)
try: try:
from tensorflow import Tensor if any('tensorflow' in str(type(a)) for a in args):
if any(isinstance(a, Tensor) for a in args): from tensorflow import Tensor
return NdArrayType.TENSORFLOW if any(isinstance(a, Tensor) for a in args):
return NdArrayType.TENSORFLOW
except Exception: except Exception:
pass pass
try: try:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment