Skip to content
Snippets Groups Projects
Commit 4d0b1d3d authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Return tensors in run_kernel

parent d15f54b3
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,7 @@ try:
import torch
except ImportError:
torch = None
from typing import Sequence, Union
from typing import Sequence, Tuple, Union
import numpy as np
......@@ -51,7 +51,7 @@ class MultiShapeDatahandling(pystencils.datahandling.SerialDataHandling):
opencl_ctx,
array_handler=None)
def add_arrays(self, description: str, spatial_shape=None):
def add_arrays(self, description: str, spatial_shape=None) -> Tuple[pystencils.Field]:
from pystencils.field import _parse_part1, _parse_description
if ':' in description:
......@@ -184,7 +184,8 @@ class PyTorchDataHandling(MultiShapeDatahandling):
def run_kernel(self, kernel_function, **kwargs):
arrays = self.gpu_arrays if self.default_target == 'gpu' else self.cpu_arrays
kernel_function(**arrays, **kwargs)
rtn = kernel_function(**arrays, **kwargs)
return rtn
def require_autograd(self, bool_val, *names):
for n in names:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment