From 5ebe72cddd1645ebcca9174f1988924ae9423698 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 25 Oct 2019 19:01:02 +0200
Subject: [PATCH] Fix bugs for non-tensor arguments in _torch_native

---
 src/pystencils_autodiff/backends/_torch_native.py         | 8 ++++----
 src/pystencils_autodiff/framework_integration/astnodes.py | 5 ++---
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 4bef573..e6b6389 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -39,11 +39,11 @@ def create_autograd_function(autodiff_obj, use_cuda):
         kwargs.update(class_kwargs)
         # TODO: drop contiguous requirement
         if use_cuda:
-            args = [a.contiguous().cuda() for a in args]
-            kwargs = {k: v.contiguous().cuda() for k, v in kwargs.items()}
+            args = [a.contiguous().cuda() if isinstance(a, torch.Tensor) else a for a in args]
+            kwargs = {k: v.contiguous().cuda() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
         else:
-            args = [a.contiguous().cpu() for a in args]
-            kwargs = {k: v.contiguous().cpu() for k, v in kwargs.items()}
+            args = [a.contiguous().cpu() if isinstance(a, torch.Tensor) else a for a in args]
+            kwargs = {k: v.contiguous().cpu() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
 
         # assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields)
         # if not any(isinstance(s, sp.Symbol) for s in args[i].shape))
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 3e5e7a3..272e690 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -15,7 +15,6 @@ from functools import reduce
 from typing import Any, List, Set
 
 import jinja2
-import numpy as np
 
 import pystencils
 import sympy as sp
@@ -268,7 +267,7 @@ copyParams.extent = {{{", ".join(reversed(self._shape))}}};
 copyParams.kind = cudaMemcpyDeviceToDevice;
 cudaMemcpy3D(&{{copy_params}});"""  # noqa
         elif self._texture.field.ndim == 2:
-            # cudaMemcpy2DToArray(cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind);
+            # noqa: cudaMemcpy2DToArray(cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind);
 
             return f"""cudaMemcpy2DToArray({array},
                     0u,
@@ -278,7 +277,7 @@ cudaMemcpy3D(&{{copy_params}});"""  # noqa
                     {self._texture.field.shape[-1]},
                     {self._texture.field.shape[-2]},
                     cudaMemcpyDeviceToDevice);
- """
+ """  # noqa
         else:
             raise NotImplementedError()
 
-- 
GitLab