From 4d0b1d3dd86b51372ad31d37f4737a6a44d83906 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Wed, 15 Jan 2020 19:26:02 +0100 Subject: [PATCH] Return tensors in run_kernel --- .../framework_integration/datahandling.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/pystencils_autodiff/framework_integration/datahandling.py b/src/pystencils_autodiff/framework_integration/datahandling.py index 071c878..98d10d1 100644 --- a/src/pystencils_autodiff/framework_integration/datahandling.py +++ b/src/pystencils_autodiff/framework_integration/datahandling.py @@ -12,7 +12,7 @@ try: import torch except ImportError: torch = None -from typing import Sequence, Union +from typing import Sequence, Tuple, Union import numpy as np @@ -51,7 +51,7 @@ class MultiShapeDatahandling(pystencils.datahandling.SerialDataHandling): opencl_ctx, array_handler=None) - def add_arrays(self, description: str, spatial_shape=None): + def add_arrays(self, description: str, spatial_shape=None) -> Tuple[pystencils.Field]: from pystencils.field import _parse_part1, _parse_description if ':' in description: @@ -184,7 +184,8 @@ class PyTorchDataHandling(MultiShapeDatahandling): def run_kernel(self, kernel_function, **kwargs): arrays = self.gpu_arrays if self.default_target == 'gpu' else self.cpu_arrays - kernel_function(**arrays, **kwargs) + rtn = kernel_function(**arrays, **kwargs) + return rtn def require_autograd(self, bool_val, *names): for n in names: -- GitLab