From 5ebe72cddd1645ebcca9174f1988924ae9423698 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 25 Oct 2019 19:01:02 +0200
Subject: [PATCH] Fix bugs for non-tensor arguments in _torch_native

---
 src/pystencils_autodiff/backends/_torch_native.py         | 8 ++++----
 src/pystencils_autodiff/framework_integration/astnodes.py | 5 ++---
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 4bef573..e6b6389 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -39,11 +39,11 @@ def create_autograd_function(autodiff_obj, use_cuda):
         kwargs.update(class_kwargs)
         # TODO: drop contiguous requirement
         if use_cuda:
-            args = [a.contiguous().cuda() for a in args]
-            kwargs = {k: v.contiguous().cuda() for k, v in kwargs.items()}
+            args = [a.contiguous().cuda() if isinstance(a, torch.Tensor) else a for a in args]
+            kwargs = {k: v.contiguous().cuda() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
         else:
-            args = [a.contiguous().cpu() for a in args]
-            kwargs = {k: v.contiguous().cpu() for k, v in kwargs.items()}
+            args = [a.contiguous().cpu() if isinstance(a, torch.Tensor) else a for a in args]
+            kwargs = {k: v.contiguous().cpu() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
 
         # assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields)
         #            if not any(isinstance(s, sp.Symbol) for s in args[i].shape))
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 3e5e7a3..272e690 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -15,7 +15,6 @@ from functools import reduce
 from typing import Any, List, Set
 
 import jinja2
-import numpy as np
 import pystencils
 import sympy as sp
 
@@ -268,7 +267,7 @@ copyParams.extent = {{{", ".join(reversed(self._shape))}}};
 copyParams.kind = cudaMemcpyDeviceToDevice;
 cudaMemcpy3D(&{{copy_params}});"""  # noqa
         elif self._texture.field.ndim == 2:
-            # cudaMemcpy2DToArray(cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind);
+            # noqa: cudaMemcpy2DToArray(cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind);
             return f"""cudaMemcpy2DToArray({array},
                        0u,
                        0u,
@@ -278,7 +277,7 @@ cudaMemcpy3D(&{{copy_params}});"""  # noqa
                        {self._texture.field.shape[-1]},
                        {self._texture.field.shape[-2]},
                        cudaMemcpyDeviceToDevice);
-    """
+    """  # noqa
         else:
             raise NotImplementedError()
 
-- 
GitLab
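
A minimal sketch of the failure this patch addresses: non-tensor arguments (plain Python floats, ints, etc.) have no .contiguous()/.cuda()/.cpu() methods, so the old unguarded comprehensions raised AttributeError; with the isinstance(..., torch.Tensor) check they are passed through unchanged. The to_device helper name below is hypothetical, a stand-in for the guarded comprehensions inside create_autograd_function, not code from the repository.

import torch

def to_device(args, use_cuda):
    # Move tensors to the requested device; leave non-tensor arguments untouched,
    # mirroring the guarded comprehensions introduced by the patch.
    if use_cuda:
        return [a.contiguous().cuda() if isinstance(a, torch.Tensor) else a for a in args]
    return [a.contiguous().cpu() if isinstance(a, torch.Tensor) else a for a in args]

# A scalar mixed in with tensors no longer raises AttributeError:
print(to_device([torch.ones(2, 2), 3.14], use_cuda=False))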