From 303171de384191caaba13d0de3650566d2c0f8ba Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 28 Feb 2020 17:33:31 +0100 Subject: [PATCH] Use cuda --- src/pystencils_autodiff/backends/_torch_native.py | 3 ++- src/pystencils_autodiff/backends/astnodes.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py index 48417af..d43cb90 100644 --- a/src/pystencils_autodiff/backends/_torch_native.py +++ b/src/pystencils_autodiff/backends/_torch_native.py @@ -62,7 +62,8 @@ def create_autograd_function(autodiff_obj, use_cuda): kwargs[field.name] = torch.zeros( field.shape, dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype), - device=next(chain(args, kwargs.values())).device) + device='cuda' if use_cuda else 'cpu') # use device of tensor + output_tensors = OrderedDict({f.name: field_to_tensor_dict.get(f, kwargs[f.name]) for f in autodiff_obj.forward_output_fields}) diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index a040b06..b8487f1 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -55,7 +55,7 @@ class Header(JinjaCppFile): class TorchTensorDestructuring(DestructuringBindingsForFieldClass): CLASS_TO_MEMBER_DICT = { - FieldPointerSymbol: "data_ptr<{dtype}>()", + FieldPointerSymbol: "data<{dtype}>()", FieldShapeSymbol: "size({dim})", FieldStrideSymbol: "strides()[{dim}]" } -- GitLab