Skip to content
Snippets Groups Projects
Commit bc8c41d3 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Avoid importing of tensorflow/torch for checking arrays

parent 053a14ff
Branches
No related merge requests found
......@@ -50,15 +50,17 @@ def get_type_of_arrays(*args):
except Exception:
pass
try:
import torch
if any(isinstance(a, torch.Tensor) for a in args):
return NdArrayType.TORCH
if any('torch' in str(type(a)) for a in args):
import torch
if any(isinstance(a, torch.Tensor) for a in args):
return NdArrayType.TORCH
except Exception as e:
print(e)
try:
from tensorflow import Tensor
if any(isinstance(a, Tensor) for a in args):
return NdArrayType.TENSORFLOW
if any('tensorflow' in str(type(a)) for a in args):
from tensorflow import Tensor
if any(isinstance(a, Tensor) for a in args):
return NdArrayType.TENSORFLOW
except Exception:
pass
try:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment