diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index eebf9b67bc419b036a072b581d3d7ea06864f9e6..351709aeed813468ed2931c327f169c9569cd588 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -215,9 +215,11 @@ class FunctionCall(Node): class WrapperFunction(pystencils.astnodes.KernelFunction): - def __init__(self, body, function_name='wrapper', target='cpu', backend='c'): + def __init__(self, body, function_name='wrapper', target='cpu', backend='c', return_type=None, return_value=None): super().__init__(body, target, backend, compile_function=None, ghost_layers=0) self.function_name = function_name + self.return_type = return_type + self.return_value = return_value def generate_kernel_call(kernel_function): diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py index edc6edc8e50b7c73af9b42c1dfa813545f7478d0..6e11cf684a67651fd1e45c92638d7321f80002cc 100644 --- a/src/pystencils_autodiff/framework_integration/printer.py +++ b/src/pystencils_autodiff/framework_integration/printer.py @@ -3,6 +3,7 @@ import functools import sympy as sp import pystencils.backends.cbackend +from pystencils.astnodes import KernelFunction from pystencils.data_types import TypedSymbol from pystencils.kernelparameters import FieldPointerSymbol from pystencils_autodiff.framework_integration.types import TemplateType @@ -38,7 +39,9 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): return "\n%s\n" % (''.join(block_contents.splitlines(True))) def _print_WrapperFunction(self, node): - super_result = super()._print_KernelFunction(node) + if node.return_value: + node._body._nodes.append(node.return_value) + super_result = self._print_KernelFunction_extended(node) if self._signatureOnly: super_result += ';' return super_result.replace('FUNC_PREFIX ', '') @@ -46,7 +49,23 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): def _print_TextureDeclaration(self, node): return str(node) - def _print_KernelFunction(self, node): + def _print_KernelFunction_extended(self, node: KernelFunction): + return_type = node.return_type if hasattr(node, 'return_type') and node.return_type else 'void' + function_arguments = [f"{self._print(s.symbol.dtype)} {s.symbol.name}" + for s in node.get_parameters() if hasattr(s.symbol, 'dtype')] + launch_bounds = "" + if self._dialect == 'cuda': + max_threads = node.indexing.max_threads_per_block() + if max_threads: + launch_bounds = f"__launch_bounds__({max_threads}) " + func_declaration = f"FUNC_PREFIX {launch_bounds} {self._print(return_type)} {node.function_name}({', '.join(function_arguments)})" # noqa + if self._signatureOnly: + return func_declaration + + body = self._print(node.body) + return func_declaration + "\n" + body + + def _print_KernelFunction(self, node, return_type='void'): if node.backend == 'gpucuda': prefix = '#define FUNC_PREFIX static __global__\n' kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='cuda', with_globals=False)