Skip to content
Snippets Groups Projects
Commit fddcf9dc authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Implement hybrid include mechanism for C interfacing

parent 86c46699
Branches
No related tags found
1 merge request!25Draft: C Interfacing
Pipeline #75815 failed
......@@ -15,6 +15,7 @@ from pystencils import (
)
from pystencils.codegen import Kernel
from pystencils.types import create_type, UserTypeSpec, PsType
from pystencilssfg.ir import SfgSourceFileType
from ..context import SfgContext, SfgCursor
from .custom import CustomGenerator
......@@ -716,6 +717,10 @@ class SfgFunctionSequencer(SfgFunctionSequencerBase):
"""Sequencer for constructing functions."""
def __call__(self, *args: SequencerArg) -> None:
# check if header is in HYBRID mode for c_interfacing enabled
if self._cursor.context.c_interfacing:
assert isinstance(self._cursor.context.header_file.file_type, SfgSourceFileType.HYBRID_HEADER)
"""Populate the function body"""
tree = make_sequence(*args)
func = SfgFunction(
......
......@@ -42,7 +42,28 @@ class SfgFilePrinter:
if file.file_type == SfgSourceFileType.HEADER:
code += "#pragma once\n\n"
includes = ""
for header in file.includes:
incl = str(header) if header.system_header else f'"{str(header)}"'
includes += f"#include {incl}\n"
if file.file_type == SfgSourceFileType.HYBRID_HEADER:
hybrid_includes = ""
for header in file.includes:
incl = str(header) if header.system_header else f'"{str(header)}"'
hybrid_includes += f"#include {incl}\n"
# include different headers and wrap around guard distinguishing C++/C compilations
code += f"""
#ifdef __cplusplus\n
{includes}
#else\n
{hybrid_includes}
#endif\n"""
else:
code += includes
for header in file.hybrid_includes:
incl = str(header) if header.system_header else f'"{str(header)}"'
code += f"#include {incl}\n"
......
......@@ -85,9 +85,9 @@ class SourceFileGenerator:
from .ir import SfgSourceFile, SfgSourceFileType
self._header_file = SfgSourceFile(
output_files[0].name, SfgSourceFileType.HEADER
)
header_type = SfgSourceFileType.HYBRID_HEADER \
if self._c_interfacing else SfgSourceFileType.HEADER
self._header_file = SfgSourceFile(output_files[0].name, header_type)
self._impl_file: SfgSourceFile | None
if self._header_only:
......@@ -164,6 +164,22 @@ class SourceFileGenerator:
)
self._header_file.includes.sort(key=self._include_sort_key)
if self._c_interfacing:
# from: https://en.cppreference.com/w/cpp/header
c_compatibility_headers = [
"<cassert", "<cctype>", "<cerrno>", "<cfenv>", "<cfloat>",
"<cinttypes>", "<climits>", "<clocale>", "<cmath>",
"<csetjmp>", "<csignal>", "<cstdarg>", "<cstddef>", "<cstdint>",
"<cstdio>", "<cstdlib>", "<cstring>", "<ctime>", "<cuchar>",
"<cwchar>", "<cwctype>"
]
for inc in self._header_file.includes:
if inc.system_header and inc.__str__() in c_compatibility_headers:
c_header = inc.__str__().replace("<c", "<")
self._header_file.hybrid_includes += HeaderFile(
c_header, system_header=True)
if self._impl_file is not None:
impl_includes = collect_includes(self._impl_file)
# If some header is already included by the generated header file, do not duplicate that inclusion
......
......@@ -181,6 +181,7 @@ SfgNamespaceElement = (
class SfgSourceFileType(Enum):
HEADER = auto()
HYBRID_HEADER = auto()
TRANSLATION_UNIT = auto()
......@@ -200,6 +201,7 @@ class SfgSourceFile:
self._file_type: SfgSourceFileType = file_type
self._prelude: str | None = prelude
self._includes: list[HeaderFile] = []
self._hybrid_includes: list[HeaderFile] = []
self._elements: list[SfgNamespaceElement] = []
@property
......@@ -230,6 +232,15 @@ class SfgSourceFile:
def includes(self, incl: Iterable[HeaderFile]):
self._includes = list(incl)
@property
def hybrid_includes(self) -> list[HeaderFile]:
"""Sequence of header files to be included at the top of this file"""
return self._hybrid_includes
@hybrid_includes.setter
def hybrid_includes(self, incl: Iterable[HeaderFile]):
self._hybrid_includes = list(incl)
@property
def elements(self) -> list[SfgNamespaceElement]:
"""Sequence of source elements comprising the body of this file"""
......
......@@ -483,7 +483,7 @@ def includes(obj: ExprLike | PsType) -> set[HeaderFile]:
case PsType():
headers = set(HeaderFile.parse(h) for h in obj.required_headers)
if isinstance(obj, PsIntegerType):
headers.add(HeaderFile.parse("<cstdint>")) # TODO: switch for stdint.h
headers.add(HeaderFile.parse("<cstdint>"))
return headers
case SfgVar(_, dtype):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment