Skip to content
Snippets Groups Projects
Commit d6df3ac1 authored by Markus Holzer's avatar Markus Holzer
Browse files

Small check

parent cd234778
No related branches found
No related tags found
No related merge requests found
Pipeline #53772 passed
...@@ -28,7 +28,7 @@ class GPUArrayHandler: ...@@ -28,7 +28,7 @@ class GPUArrayHandler:
@staticmethod @staticmethod
def to_gpu(numpy_array): def to_gpu(numpy_array):
swaps = _get_index_swaps(numpy_array) swaps = _get_index_swaps(numpy_array)
if numpy_array.base is not None: if numpy_array.base is not None and isinstance(numpy_array.base, np.ndarray):
with cp.cuda.Device(pystencils.GPU_DEVICE): with cp.cuda.Device(pystencils.GPU_DEVICE):
gpu_array = cp.asarray(numpy_array.base) gpu_array = cp.asarray(numpy_array.base)
for a, b in reversed(swaps): for a, b in reversed(swaps):
...@@ -39,7 +39,7 @@ class GPUArrayHandler: ...@@ -39,7 +39,7 @@ class GPUArrayHandler:
@staticmethod @staticmethod
def upload(array, numpy_array): def upload(array, numpy_array):
if numpy_array.base is not None: if numpy_array.base is not None and isinstance(numpy_array.base, np.ndarray):
with cp.cuda.Device(pystencils.GPU_DEVICE): with cp.cuda.Device(pystencils.GPU_DEVICE):
array.base.set(numpy_array.base) array.base.set(numpy_array.base)
else: else:
...@@ -48,7 +48,7 @@ class GPUArrayHandler: ...@@ -48,7 +48,7 @@ class GPUArrayHandler:
@staticmethod @staticmethod
def download(array, numpy_array): def download(array, numpy_array):
if numpy_array.base is not None: if numpy_array.base is not None and isinstance(numpy_array.base, np.ndarray):
with cp.cuda.Device(pystencils.GPU_DEVICE): with cp.cuda.Device(pystencils.GPU_DEVICE):
numpy_array.base[:] = array.base.get() numpy_array.base[:] = array.base.get()
else: else:
...@@ -92,7 +92,7 @@ class GPUNotAvailableHandler: ...@@ -92,7 +92,7 @@ class GPUNotAvailableHandler:
def _get_index_swaps(array): def _get_index_swaps(array):
swaps = [] swaps = []
if array.base is not None: if array.base is not None and isinstance(array.base, np.ndarray):
for stride in array.base.strides: for stride in array.base.strides:
index_base = array.base.strides.index(stride) index_base = array.base.strides.index(stride)
index_view = array.strides.index(stride) index_view = array.strides.index(stride)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment