diff --git a/src/pystencils/backend/emission/base_printer.py b/src/pystencils/backend/emission/base_printer.py index 6808d4731d816cf32bab28a173908261d19df479..cc4b50e217d68a4bba28c2c95705464a91182212 100644 --- a/src/pystencils/backend/emission/base_printer.py +++ b/src/pystencils/backend/emission/base_printer.py @@ -57,6 +57,7 @@ from ..extensions.foreign_ast import PsForeignExpression from ..memory import PsSymbol from ..constants import PsConstant from ...types import PsType +from ...codegen import Target if TYPE_CHECKING: from ...codegen import Kernel @@ -382,7 +383,7 @@ class BasePrinter(ABC): from ...codegen import GpuKernel sig_parts = [self._func_prefix] if self._func_prefix is not None else [] - if isinstance(func, GpuKernel): + if isinstance(func, GpuKernel) and func.target == Target.CUDA: sig_parts.append("__global__") sig_parts += ["void", func.name, f"({params_str})"] signature = " ".join(sig_parts)