diff --git a/src/pystencils_autodiff/framework_integration/datahandling.py b/src/pystencils_autodiff/framework_integration/datahandling.py index 071c8783e9fc40cfecd4477437f7f789bacc3386..98d10d15a58b2708be0227eb304a2eb4985fedb3 100644 --- a/src/pystencils_autodiff/framework_integration/datahandling.py +++ b/src/pystencils_autodiff/framework_integration/datahandling.py @@ -12,7 +12,7 @@ try: import torch except ImportError: torch = None -from typing import Sequence, Union +from typing import Sequence, Tuple, Union import numpy as np @@ -51,7 +51,7 @@ class MultiShapeDatahandling(pystencils.datahandling.SerialDataHandling): opencl_ctx, array_handler=None) - def add_arrays(self, description: str, spatial_shape=None): + def add_arrays(self, description: str, spatial_shape=None) -> Tuple[pystencils.Field]: from pystencils.field import _parse_part1, _parse_description if ':' in description: @@ -184,7 +184,8 @@ class PyTorchDataHandling(MultiShapeDatahandling): def run_kernel(self, kernel_function, **kwargs): arrays = self.gpu_arrays if self.default_target == 'gpu' else self.cpu_arrays - kernel_function(**arrays, **kwargs) + rtn = kernel_function(**arrays, **kwargs) + return rtn def require_autograd(self, bool_val, *names): for n in names: