From f76aa82753edee6fe79ee8c167e46703d7b74962 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 27 Jan 2025 16:59:09 +0100 Subject: [PATCH] fix __global__ in gpu kernels --- src/pystencils/backend/emission/base_printer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/pystencils/backend/emission/base_printer.py b/src/pystencils/backend/emission/base_printer.py index adb9c232b..6808d4731 100644 --- a/src/pystencils/backend/emission/base_printer.py +++ b/src/pystencils/backend/emission/base_printer.py @@ -375,11 +375,16 @@ class BasePrinter(ABC): ) def print_signature(self, func: Kernel) -> str: - prefix = self._func_prefix params_str = ", ".join( f"{self._type_str(p.dtype)} {p.name}" for p in func.parameters ) - sig_parts = ([prefix] if prefix is not None else []) + ["void", func.name, f"({params_str})"] + + from ...codegen import GpuKernel + + sig_parts = [self._func_prefix] if self._func_prefix is not None else [] + if isinstance(func, GpuKernel): + sig_parts.append("__global__") + sig_parts += ["void", func.name, f"({params_str})"] signature = " ".join(sig_parts) return signature -- GitLab