From 197a3d9d05677de9c6e14c2ffb3a871aecab59e4 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Thu, 6 Mar 2025 17:47:51 +0100
Subject: [PATCH] Add composer for "extern C" prefix

---
 src/pystencilssfg/composer/basic_composer.py | 8 +++++++-
 src/pystencilssfg/emission/file_printer.py   | 5 ++++-
 src/pystencilssfg/ir/entities.py             | 6 ++++++
 src/pystencilssfg/lang/expressions.py        | 2 +-
 4 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py
index 31337a6..9e58d78 100644
--- a/src/pystencilssfg/composer/basic_composer.py
+++ b/src/pystencilssfg/composer/basic_composer.py
@@ -91,6 +91,7 @@ class KernelsAdder:
         self._cursor = cursor
         self._kernel_namespace = knamespace
         self._inline: bool = False
+        self._externC: bool = False
         self._loc: SfgNamespaceBlock | None = None
 
     def inline(self) -> KernelsAdder:
@@ -98,6 +99,11 @@ class KernelsAdder:
         self._inline = True
         return self
 
+    def externC(self) -> KernelsAdder:
+        """Generate kernel definitions ``extern "C"`` in the header file."""
+        self._externC = True
+        return self
+
     def add(self, kernel: Kernel, name: str | None = None):
         """Adds an existing pystencils AST to this namespace.
         If a name is specified, the AST's function name is changed."""
@@ -116,7 +122,7 @@ class KernelsAdder:
             kernel.name = kernel_name
 
         khandle = SfgKernelHandle(
-            kernel_name, self._kernel_namespace, kernel, inline=self._inline
+            kernel_name, self._kernel_namespace, kernel, inline=self._inline, externC=self._externC
         )
         self._kernel_namespace.add_kernel(khandle)
 
diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py
index 648e419..d6b9296 100644
--- a/src/pystencilssfg/emission/file_printer.py
+++ b/src/pystencilssfg/emission/file_printer.py
@@ -84,9 +84,12 @@ class SfgFilePrinter:
     ) -> str:
         match declared_entity:
             case SfgKernelHandle(kernel):
+                func_prefix = "extern C" if declared_entity.externC else ""
+                func_prefix += " inline" if declared_entity.inline else ""
+
                 kernel_printer = CAstPrinter(
                     indent_width=self._indent_width,
-                    func_prefix="inline" if declared_entity.inline else None,
+                    func_prefix=func_prefix,
                 )
                 return kernel_printer.print_signature(kernel) + ";"
 
diff --git a/src/pystencilssfg/ir/entities.py b/src/pystencilssfg/ir/entities.py
index 0edde22..850bbde 100644
--- a/src/pystencilssfg/ir/entities.py
+++ b/src/pystencilssfg/ir/entities.py
@@ -147,6 +147,7 @@ class SfgKernelHandle(SfgCodeEntity):
         namespace: SfgKernelNamespace,
         kernel: Kernel,
         inline: bool = False,
+        externC: bool = False,
     ):
         super().__init__(name, namespace)
 
@@ -154,6 +155,7 @@ class SfgKernelHandle(SfgCodeEntity):
         self._parameters = [SfgKernelParamVar(p) for p in kernel.parameters]
 
         self._inline: bool = inline
+        self._externC: bool = externC
 
         self._scalar_params: set[SfgVar] = set()
         self._fields: set[Field] = set()
@@ -188,6 +190,10 @@ class SfgKernelHandle(SfgCodeEntity):
     def inline(self) -> bool:
         return self._inline
 
+    @property
+    def externC(self) -> bool:
+        return self._externC
+
 
 class SfgKernelNamespace(SfgNamespace):
     """A namespace grouping together a number of kernels."""
diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py
index 135a54e..abe863a 100644
--- a/src/pystencilssfg/lang/expressions.py
+++ b/src/pystencilssfg/lang/expressions.py
@@ -483,7 +483,7 @@ def includes(obj: ExprLike | PsType) -> set[HeaderFile]:
         case PsType():
             headers = set(HeaderFile.parse(h) for h in obj.required_headers)
             if isinstance(obj, PsIntegerType):
-                headers.add(HeaderFile.parse("<cstdint>"))
+                headers.add(HeaderFile.parse("<cstdint>"))  # TODO: switch for stdint.h
             return headers
 
         case SfgVar(_, dtype):
-- 
GitLab