From 4d0b1d3dd86b51372ad31d37f4737a6a44d83906 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 15 Jan 2020 19:26:02 +0100
Subject: [PATCH] Return tensors in run_kernel

---
 .../framework_integration/datahandling.py                  | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/pystencils_autodiff/framework_integration/datahandling.py b/src/pystencils_autodiff/framework_integration/datahandling.py
index 071c878..98d10d1 100644
--- a/src/pystencils_autodiff/framework_integration/datahandling.py
+++ b/src/pystencils_autodiff/framework_integration/datahandling.py
@@ -12,7 +12,7 @@ try:
     import torch
 except ImportError:
     torch = None
-from typing import Sequence, Union
+from typing import Sequence, Tuple, Union
 
 import numpy as np
 
@@ -51,7 +51,7 @@ class MultiShapeDatahandling(pystencils.datahandling.SerialDataHandling):
             opencl_ctx,
             array_handler=None)
 
-    def add_arrays(self, description: str, spatial_shape=None):
+    def add_arrays(self, description: str, spatial_shape=None) -> Tuple[pystencils.Field]:
         from pystencils.field import _parse_part1, _parse_description
 
         if ':' in description:
@@ -184,7 +184,8 @@ class PyTorchDataHandling(MultiShapeDatahandling):
 
     def run_kernel(self, kernel_function, **kwargs):
         arrays = self.gpu_arrays if self.default_target == 'gpu' else self.cpu_arrays
-        kernel_function(**arrays, **kwargs)
+        rtn = kernel_function(**arrays, **kwargs)
+        return rtn
 
     def require_autograd(self, bool_val, *names):
         for n in names:
-- 
GitLab