From 513acdfec36af16e6df1ff559c611a1da6704f40 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 7 Oct 2020 16:49:00 +0200
Subject: [PATCH] Allow return types for WrapperFunction

---
 .../framework_integration/astnodes.py         |  4 +++-
 .../framework_integration/printer.py          | 23 +++++++++++++++++--
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index eebf9b6..351709a 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -215,9 +215,11 @@ class FunctionCall(Node):
 
 class WrapperFunction(pystencils.astnodes.KernelFunction):
 
-    def __init__(self, body, function_name='wrapper', target='cpu', backend='c'):
+    def __init__(self, body, function_name='wrapper', target='cpu', backend='c', return_type=None, return_value=None):
         super().__init__(body, target, backend, compile_function=None, ghost_layers=0)
         self.function_name = function_name
+        self.return_type = return_type
+        self.return_value = return_value
 
 
 def generate_kernel_call(kernel_function):
diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py
index edc6edc..6e11cf6 100644
--- a/src/pystencils_autodiff/framework_integration/printer.py
+++ b/src/pystencils_autodiff/framework_integration/printer.py
@@ -3,6 +3,7 @@ import functools
 import sympy as sp
 
 import pystencils.backends.cbackend
+from pystencils.astnodes import KernelFunction
 from pystencils.data_types import TypedSymbol
 from pystencils.kernelparameters import FieldPointerSymbol
 from pystencils_autodiff.framework_integration.types import TemplateType
@@ -38,7 +39,9 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
         return "\n%s\n" % (''.join(block_contents.splitlines(True)))
 
     def _print_WrapperFunction(self, node):
-        super_result = super()._print_KernelFunction(node)
+        if node.return_value:
+            node._body._nodes.append(node.return_value)
+        super_result = self._print_KernelFunction_extended(node)
         if self._signatureOnly:
             super_result += ';'
         return super_result.replace('FUNC_PREFIX ', '')
@@ -46,7 +49,23 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
     def _print_TextureDeclaration(self, node):
         return str(node)
 
-    def _print_KernelFunction(self, node):
+    def _print_KernelFunction_extended(self, node: KernelFunction):
+        return_type = node.return_type if hasattr(node, 'return_type') and node.return_type else 'void'
+        function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}"
+                              for s in node.get_parameters() if hasattr(s.symbol, 'dtype')]
+        launch_bounds = ""
+        if self._dialect == 'cuda':
+            max_threads = node.indexing.max_threads_per_block()
+            if max_threads:
+                launch_bounds = f"__launch_bounds__({max_threads}) "
+        func_declaration = f"FUNC_PREFIX {launch_bounds} {self._print(return_type)} {node.function_name}({', '.join(function_arguments)})"  # noqa
+        if self._signatureOnly:
+            return func_declaration
+
+        body = self._print(node.body)
+        return func_declaration + "\n" + body
+
+    def _print_KernelFunction(self, node, return_type='void'):
         if node.backend == 'gpucuda':
             prefix = '#define FUNC_PREFIX static __global__\n'
             kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='cuda', with_globals=False)
-- 
GitLab