diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 75c023bc753d02b95f55190439f8b99262e23399..d5989709460bb2e103537f5ec231538e46aa501f 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -15,6 +15,7 @@ from pystencils import ( ) from pystencils.codegen import Kernel from pystencils.types import create_type, UserTypeSpec, PsType +from pystencilssfg.ir import SfgSourceFileType from ..context import SfgContext, SfgCursor from .custom import CustomGenerator @@ -716,6 +717,10 @@ class SfgFunctionSequencer(SfgFunctionSequencerBase): """Sequencer for constructing functions.""" def __call__(self, *args: SequencerArg) -> None: + # check if header is in HYBRID mode for c_interfacing enabled + if self._cursor.context.c_interfacing: + assert isinstance(self._cursor.context.header_file.file_type, SfgSourceFileType.HYBRID_HEADER) + """Populate the function body""" tree = make_sequence(*args) func = SfgFunction( diff --git a/src/pystencilssfg/emission/file_printer.py b/src/pystencilssfg/emission/file_printer.py index 74ec3a22edcdf5a9f6656ebd1f9d05155e2b1740..b2f6692b317a51d93129717a1caec410816c71f5 100644 --- a/src/pystencilssfg/emission/file_printer.py +++ b/src/pystencilssfg/emission/file_printer.py @@ -42,7 +42,28 @@ class SfgFilePrinter: if file.file_type == SfgSourceFileType.HEADER: code += "#pragma once\n\n" + includes = "" for header in file.includes: + incl = str(header) if header.system_header else f'"{str(header)}"' + includes += f"#include {incl}\n" + + if file.file_type == SfgSourceFileType.HYBRID_HEADER: + hybrid_includes = "" + for header in file.includes: + incl = str(header) if header.system_header else f'"{str(header)}"' + hybrid_includes += f"#include {incl}\n" + + # include different headers and wrap around guard distinguishing C++/C compilations + code += f""" + #ifdef __cplusplus\n + {includes} + #else\n + {hybrid_includes} + #endif\n""" + else: + code += includes + + for header in file.hybrid_includes: incl = str(header) if header.system_header else f'"{str(header)}"' code += f"#include {incl}\n" diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index 5110fcf1f7e4cfdc04c77c5e5b969c5f4337c6ce..1eb30ce878e394f81ed96663bb5ba3c414c8b45b 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -85,9 +85,9 @@ class SourceFileGenerator: from .ir import SfgSourceFile, SfgSourceFileType - self._header_file = SfgSourceFile( - output_files[0].name, SfgSourceFileType.HEADER - ) + header_type = SfgSourceFileType.HYBRID_HEADER \ + if self._c_interfacing else SfgSourceFileType.HEADER + self._header_file = SfgSourceFile(output_files[0].name, header_type) self._impl_file: SfgSourceFile | None if self._header_only: @@ -164,6 +164,22 @@ class SourceFileGenerator: ) self._header_file.includes.sort(key=self._include_sort_key) + if self._c_interfacing: + # from: https://en.cppreference.com/w/cpp/header + c_compatibility_headers = [ + "<cassert", "<cctype>", "<cerrno>", "<cfenv>", "<cfloat>", + "<cinttypes>", "<climits>", "<clocale>", "<cmath>", + "<csetjmp>", "<csignal>", "<cstdarg>", "<cstddef>", "<cstdint>", + "<cstdio>", "<cstdlib>", "<cstring>", "<ctime>", "<cuchar>", + "<cwchar>", "<cwctype>" + ] + + for inc in self._header_file.includes: + if inc.system_header and inc.__str__() in c_compatibility_headers: + c_header = inc.__str__().replace("<c", "<") + self._header_file.hybrid_includes += HeaderFile( + c_header, system_header=True) + if self._impl_file is not None: impl_includes = collect_includes(self._impl_file) # If some header is already included by the generated header file, do not duplicate that inclusion diff --git a/src/pystencilssfg/ir/syntax.py b/src/pystencilssfg/ir/syntax.py index cdbd4c283b6bb0078e1051f89565b3b6b32d8d21..71d7fd04fd888da2a37dd40c1491cffa4106e18c 100644 --- a/src/pystencilssfg/ir/syntax.py +++ b/src/pystencilssfg/ir/syntax.py @@ -181,6 +181,7 @@ SfgNamespaceElement = ( class SfgSourceFileType(Enum): HEADER = auto() + HYBRID_HEADER = auto() TRANSLATION_UNIT = auto() @@ -200,6 +201,7 @@ class SfgSourceFile: self._file_type: SfgSourceFileType = file_type self._prelude: str | None = prelude self._includes: list[HeaderFile] = [] + self._hybrid_includes: list[HeaderFile] = [] self._elements: list[SfgNamespaceElement] = [] @property @@ -230,6 +232,15 @@ class SfgSourceFile: def includes(self, incl: Iterable[HeaderFile]): self._includes = list(incl) + @property + def hybrid_includes(self) -> list[HeaderFile]: + """Sequence of header files to be included at the top of this file""" + return self._hybrid_includes + + @hybrid_includes.setter + def hybrid_includes(self, incl: Iterable[HeaderFile]): + self._hybrid_includes = list(incl) + @property def elements(self) -> list[SfgNamespaceElement]: """Sequence of source elements comprising the body of this file""" diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index abe863a96252e6e1cd97e1c28ab62445ee9cdf89..135a54eed92e4ba214244c8f46323ea81f6610db 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -483,7 +483,7 @@ def includes(obj: ExprLike | PsType) -> set[HeaderFile]: case PsType(): headers = set(HeaderFile.parse(h) for h in obj.required_headers) if isinstance(obj, PsIntegerType): - headers.add(HeaderFile.parse("<cstdint>")) # TODO: switch for stdint.h + headers.add(HeaderFile.parse("<cstdint>")) return headers case SfgVar(_, dtype):