diff --git a/pystencils/gpu/gpu_array_handler.py b/pystencils/gpu/gpu_array_handler.py index f25ba5ffc2de19378bbcecd3ac7d4cbee9f4cf83..9031036f42a29f5a8adb2e08905c990ec8b441e8 100644 --- a/pystencils/gpu/gpu_array_handler.py +++ b/pystencils/gpu/gpu_array_handler.py @@ -28,7 +28,7 @@ class GPUArrayHandler: @staticmethod def to_gpu(numpy_array): swaps = _get_index_swaps(numpy_array) - if numpy_array.base is not None: + if numpy_array.base is not None and isinstance(numpy_array.base, np.ndarray): with cp.cuda.Device(pystencils.GPU_DEVICE): gpu_array = cp.asarray(numpy_array.base) for a, b in reversed(swaps): @@ -39,7 +39,7 @@ class GPUArrayHandler: @staticmethod def upload(array, numpy_array): - if numpy_array.base is not None: + if numpy_array.base is not None and isinstance(numpy_array.base, np.ndarray): with cp.cuda.Device(pystencils.GPU_DEVICE): array.base.set(numpy_array.base) else: @@ -48,7 +48,7 @@ class GPUArrayHandler: @staticmethod def download(array, numpy_array): - if numpy_array.base is not None: + if numpy_array.base is not None and isinstance(numpy_array.base, np.ndarray): with cp.cuda.Device(pystencils.GPU_DEVICE): numpy_array.base[:] = array.base.get() else: @@ -92,7 +92,7 @@ class GPUNotAvailableHandler: def _get_index_swaps(array): swaps = [] - if array.base is not None: + if array.base is not None and isinstance(array.base, np.ndarray): for stride in array.base.strides: index_base = array.base.strides.index(stride) index_view = array.strides.index(stride)