diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 4bef57312a4ca5c8e3e4d0f52330e3a3b60cf530..e6b63897c3629db50f408f003a5c746c944a8a5e 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -39,11 +39,11 @@ def create_autograd_function(autodiff_obj, use_cuda):
         kwargs.update(class_kwargs)
         # TODO: drop contiguous requirement
         if use_cuda:
-            args = [a.contiguous().cuda() for a in args]
-            kwargs = {k: v.contiguous().cuda() for k, v in kwargs.items()}
+            args = [a.contiguous().cuda() if isinstance(a, torch.Tensor) else a for a in args]
+            kwargs = {k: v.contiguous().cuda() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
         else:
-            args = [a.contiguous().cpu() for a in args]
-            kwargs = {k: v.contiguous().cpu() for k, v in kwargs.items()}
+            args = [a.contiguous().cpu() if isinstance(a, torch.Tensor) else a for a in args]
+            kwargs = {k: v.contiguous().cpu() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
 
         # assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields)
         #            if not any(isinstance(s, sp.Symbol) for s in args[i].shape))
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 3e5e7a35366b371df389ee04e1a708ac5fe84094..272e69007f3b33720d5bd1aec640291fb2ab0b0e 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -15,7 +15,6 @@ from functools import reduce
 from typing import Any, List, Set
 
 import jinja2
-import numpy as np
 import pystencils
 import sympy as sp
 
@@ -268,7 +267,7 @@ copyParams.extent = {{{", ".join(reversed(self._shape))}}};
 copyParams.kind = cudaMemcpyDeviceToDevice;
 cudaMemcpy3D(&{{copy_params}});"""  # noqa
         elif self._texture.field.ndim == 2:
-            # cudaMemcpy2DToArray(cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind);
+            # noqa: cudaMemcpy2DToArray(cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind);
             return f"""cudaMemcpy2DToArray({array},
                        0u,
                        0u,
@@ -278,7 +277,7 @@ cudaMemcpy3D(&{{copy_params}});"""  # noqa
                        {self._texture.field.shape[-1]},
                        {self._texture.field.shape[-2]},
                        cudaMemcpyDeviceToDevice);
-    """
+    """  # noqa
         else:
             raise NotImplementedError()
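
Reviewer note: the change in `_torch_native.py` only guards the `.contiguous().cuda()` / `.contiguous().cpu()` calls so that non-tensor arguments (e.g. plain scalars or symbolic shape values passed through `args`/`kwargs`) are forwarded untouched instead of raising `AttributeError`. Below is a minimal standalone sketch of that pattern; the helper name `to_device` and the use of `.to(device)` in place of `.cuda()`/`.cpu()` are illustration choices, not code from this repository.

```python
import torch


def to_device(values, use_cuda):
    """Move only torch tensors to the target device; pass other values through."""
    device = "cuda" if use_cuda else "cpu"
    return [v.contiguous().to(device) if isinstance(v, torch.Tensor) else v
            for v in values]


if __name__ == "__main__":
    # Tensors mixed with plain Python values, as can happen for kernel arguments.
    mixed = [torch.zeros(4, 4), 42, "field_name"]
    # Without the isinstance guard, (42).contiguous() would raise AttributeError.
    moved = to_device(mixed, use_cuda=torch.cuda.is_available())
    print([type(v).__name__ for v in moved])
```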