From 303171de384191caaba13d0de3650566d2c0f8ba Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 28 Feb 2020 17:33:31 +0100
Subject: [PATCH] Use cuda

---
 src/pystencils_autodiff/backends/_torch_native.py | 3 ++-
 src/pystencils_autodiff/backends/astnodes.py      | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 48417af..d43cb90 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -62,7 +62,8 @@ def create_autograd_function(autodiff_obj, use_cuda):
                 kwargs[field.name] = torch.zeros(
                     field.shape,
                     dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
-                    device=next(chain(args, kwargs.values())).device)
+                    device='cuda' if use_cuda else 'cpu')  # use device of tensor
+
         output_tensors = OrderedDict({f.name:
                                       field_to_tensor_dict.get(f, kwargs[f.name])
                                       for f in autodiff_obj.forward_output_fields})
diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index a040b06..b8487f1 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -55,7 +55,7 @@ class Header(JinjaCppFile):
 
 class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
     CLASS_TO_MEMBER_DICT = {
-        FieldPointerSymbol: "data_ptr<{dtype}>()",
+        FieldPointerSymbol: "data<{dtype}>()",
         FieldShapeSymbol: "size({dim})",
         FieldStrideSymbol: "strides()[{dim}]"
     }
-- 
GitLab