Skip to content
Snippets Groups Projects
Commit 4a04ef24 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Define GpuPointerHolder only if HAS_PYCUDA

parent b80e39f7
No related branches found
No related tags found
No related merge requests found
......@@ -7,8 +7,9 @@ try:
import pycuda.autoinit
import pycuda.gpuarray
import pycuda.driver
HAS_PYCUDA = True
except Exception:
pass
HAS_PYCUDA = False
def create_autograd_function(autodiff_obj, inputfield_to_tensor_dict, forward_loop, backward_loop,
......@@ -123,20 +124,23 @@ def gpuarray_to_tensor(gpuarray, context=None):
return out
class GpuPointerHolder(pycuda.driver.PointerHolderBase):
if HAS_PYCUDA:
class GpuPointerHolder(pycuda.driver.PointerHolderBase):
def __init__(self, tensor):
super().__init__()
self.tensor = tensor
self.gpudata = tensor.data_ptr()
def __init__(self, tensor):
super().__init__()
self.tensor = tensor
self.gpudata = tensor.data_ptr()
def get_pointer(self):
return self.tensor.data_ptr()
def get_pointer(self):
return self.tensor.data_ptr()
def __int__(self):
return self.__index__()
def __int__(self):
return self.__index__()
# without an __index__ method, arithmetic calls to the GPUArray backed by this pointer fail
# not sure why, this needs to return some integer, apparently
def __index__(self):
return self.gpudata
# without an __index__ method, arithmetic calls to the GPUArray backed by this pointer fail
# not sure why, this needs to return some integer, apparently
def __index__(self):
return self.gpudata
else:
GpuPointerHolder = None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment