From 513acdfec36af16e6df1ff559c611a1da6704f40 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Wed, 7 Oct 2020 16:49:00 +0200 Subject: [PATCH] Allow return types for WrapperFunction --- .../framework_integration/astnodes.py | 4 +++- .../framework_integration/printer.py | 23 +++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index eebf9b6..351709a 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -215,9 +215,11 @@ class FunctionCall(Node): class WrapperFunction(pystencils.astnodes.KernelFunction): - def __init__(self, body, function_name='wrapper', target='cpu', backend='c'): + def __init__(self, body, function_name='wrapper', target='cpu', backend='c', return_type=None, return_value=None): super().__init__(body, target, backend, compile_function=None, ghost_layers=0) self.function_name = function_name + self.return_type = return_type + self.return_value = return_value def generate_kernel_call(kernel_function): diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py index edc6edc..6e11cf6 100644 --- a/src/pystencils_autodiff/framework_integration/printer.py +++ b/src/pystencils_autodiff/framework_integration/printer.py @@ -3,6 +3,7 @@ import functools import sympy as sp import pystencils.backends.cbackend +from pystencils.astnodes import KernelFunction from pystencils.data_types import TypedSymbol from pystencils.kernelparameters import FieldPointerSymbol from pystencils_autodiff.framework_integration.types import TemplateType @@ -38,7 +39,9 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): return "\n%s\n" % (''.join(block_contents.splitlines(True))) def _print_WrapperFunction(self, node): - super_result = super()._print_KernelFunction(node) + if node.return_value: + node._body._nodes.append(node.return_value) + super_result = self._print_KernelFunction_extended(node) if self._signatureOnly: super_result += ';' return super_result.replace('FUNC_PREFIX ', '') @@ -46,7 +49,23 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): def _print_TextureDeclaration(self, node): return str(node) - def _print_KernelFunction(self, node): + def _print_KernelFunction_extended(self, node: KernelFunction): + return_type = node.return_type if hasattr(node, 'return_type') and node.return_type else 'void' + function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" + for s in node.get_parameters() if hasattr(s.symbol, 'dtype')] + launch_bounds = "" + if self._dialect == 'cuda': + max_threads = node.indexing.max_threads_per_block() + if max_threads: + launch_bounds = f"__launch_bounds__({max_threads}) " + func_declaration = f"FUNC_PREFIX {launch_bounds} {self._print(return_type)} {node.function_name}({', '.join(function_arguments)})" # noqa + if self._signatureOnly: + return func_declaration + + body = self._print(node.body) + return func_declaration + "\n" + body + + def _print_KernelFunction(self, node, return_type='void'): if node.backend == 'gpucuda': prefix = '#define FUNC_PREFIX static __global__\n' kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='cuda', with_globals=False) -- GitLab