From 4a04ef24fc8673e1468f9a2c09b911001ea08261 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 7 Aug 2019 13:04:53 +0200
Subject: [PATCH] Define GpuPointerHolder only if HAS_PYCUDA

---
 tests/lbm/backends/_pytorch.py | 32 ++++++++++++++++++--------------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/tests/lbm/backends/_pytorch.py b/tests/lbm/backends/_pytorch.py
index c131091..289fc0b 100644
--- a/tests/lbm/backends/_pytorch.py
+++ b/tests/lbm/backends/_pytorch.py
@@ -7,8 +7,9 @@ try:
     import pycuda.autoinit
     import pycuda.gpuarray
     import pycuda.driver
+    HAS_PYCUDA = True
 except Exception:
-    pass
+    HAS_PYCUDA = False
 
 
 def create_autograd_function(autodiff_obj, inputfield_to_tensor_dict, forward_loop, backward_loop,
@@ -123,20 +124,23 @@ def gpuarray_to_tensor(gpuarray, context=None):
     return out
 
 
-class GpuPointerHolder(pycuda.driver.PointerHolderBase):
+if HAS_PYCUDA:
+    class GpuPointerHolder(pycuda.driver.PointerHolderBase):
 
-    def __init__(self, tensor):
-        super().__init__()
-        self.tensor = tensor
-        self.gpudata = tensor.data_ptr()
+        def __init__(self, tensor):
+            super().__init__()
+            self.tensor = tensor
+            self.gpudata = tensor.data_ptr()
 
-    def get_pointer(self):
-        return self.tensor.data_ptr()
+        def get_pointer(self):
+            return self.tensor.data_ptr()
 
-    def __int__(self):
-        return self.__index__()
+        def __int__(self):
+            return self.__index__()
 
-    # without an __index__ method, arithmetic calls to the GPUArray backed by this pointer fail
-    # not sure why, this needs to return some integer, apparently
-    def __index__(self):
-        return self.gpudata
+        # without an __index__ method, arithmetic calls to the GPUArray backed by this pointer fail
+        # not sure why, this needs to return some integer, apparently
+        def __index__(self):
+            return self.gpudata
+else:
+    GpuPointerHolder = None
-- 
GitLab