Skip to content
Snippets Groups Projects
Commit 5ebe72cd authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fix bugs for non-tensor arguments in _torch_native

parent e769999f
Branches
Tags
No related merge requests found
Pipeline #19149 failed
......@@ -39,11 +39,11 @@ def create_autograd_function(autodiff_obj, use_cuda):
kwargs.update(class_kwargs)
# TODO: drop contiguous requirement
if use_cuda:
args = [a.contiguous().cuda() for a in args]
kwargs = {k: v.contiguous().cuda() for k, v in kwargs.items()}
args = [a.contiguous().cuda() if isinstance(a, torch.Tensor) else a for a in args]
kwargs = {k: v.contiguous().cuda() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
else:
args = [a.contiguous().cpu() for a in args]
kwargs = {k: v.contiguous().cpu() for k, v in kwargs.items()}
args = [a.contiguous().cpu() if isinstance(a, torch.Tensor) else a for a in args]
kwargs = {k: v.contiguous().cpu() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
# assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields)
# if not any(isinstance(s, sp.Symbol) for s in args[i].shape))
......
......@@ -15,7 +15,6 @@ from functools import reduce
from typing import Any, List, Set
import jinja2
import numpy as np
import pystencils
import sympy as sp
......@@ -268,7 +267,7 @@ copyParams.extent = {{{", ".join(reversed(self._shape))}}};
copyParams.kind = cudaMemcpyDeviceToDevice;
cudaMemcpy3D(&{{copy_params}});""" # noqa
elif self._texture.field.ndim == 2:
# cudaMemcpy2DToArray(cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind);
# noqa: cudaMemcpy2DToArray(cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind);
return f"""cudaMemcpy2DToArray({array},
0u,
......@@ -278,7 +277,7 @@ cudaMemcpy3D(&{{copy_params}});""" # noqa
{self._texture.field.shape[-1]},
{self._texture.field.shape[-2]},
cudaMemcpyDeviceToDevice);
"""
""" # noqa
else:
raise NotImplementedError()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment