Skip to content
Snippets Groups Projects
Commit af0e3408 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix header collection

parent cb4a449d
No related branches found
No related tags found
No related merge requests found
Pipeline #58337 passed
from typing import Generator, Sequence
from .configuration import SfgCodeStyle
from .visitors import CollectIncludes
from .source_components import (
SfgHeaderInclude,
SfgKernelNamespace,
......@@ -156,9 +155,6 @@ class SfgContext:
self._functions[func.name] = func
self._declarations_ordered.append(func)
for incl in CollectIncludes().visit(func):
self.add_include(incl)
# ----------------------------------------------------------------------------------------------
# Classes
# ----------------------------------------------------------------------------------------------
......@@ -176,9 +172,6 @@ class SfgContext:
self._classes[cls.class_name] = cls
self._declarations_ordered.append(cls)
for incl in CollectIncludes().visit(cls):
self.add_include(incl)
# ----------------------------------------------------------------------------------------------
# Declarations in order of addition
# ----------------------------------------------------------------------------------------------
......
......@@ -2,6 +2,7 @@ from os import path, makedirs
from ..configuration import SfgOutputSpec
from ..context import SfgContext
from .prepare import prepare_context
from .printers import SfgHeaderPrinter, SfgImplPrinter
from .clang_format import invoke_clang_format
......@@ -20,10 +21,12 @@ class HeaderSourcePairEmitter:
def output_files(self) -> tuple[str, str]:
return (
path.join(self._output_directory, self._header_filename),
path.join(self._output_directory, self._impl_filename)
path.join(self._output_directory, self._impl_filename),
)
def write_files(self, ctx: SfgContext):
ctx = prepare_context(ctx)
header_printer = SfgHeaderPrinter(ctx, self._ospec)
impl_printer = SfgImplPrinter(ctx, self._ospec)
......@@ -35,8 +38,8 @@ class HeaderSourcePairEmitter:
makedirs(self._output_directory, exist_ok=True)
with open(self._ospec.get_header_filepath(), 'w') as headerfile:
with open(self._ospec.get_header_filepath(), "w") as headerfile:
headerfile.write(header)
with open(self._ospec.get_impl_filepath(), 'w') as cppfile:
with open(self._ospec.get_impl_filepath(), "w") as cppfile:
cppfile.write(impl)
from ..context import SfgContext
from ..visitors import CollectIncludes
def prepare_context(ctx: SfgContext):
"""Prepares a populated context for printing. Make sure to run this function on the
[SfgContext][pystencilssfg.SfgContext] before passing it to a printer.
Steps:
- Collection of includes: All defined functions and classes are traversed to collect all required
header includes
"""
# Collect all includes
required_includes = CollectIncludes().visit(ctx)
for incl in required_includes:
ctx.add_include(incl)
return ctx
......@@ -440,9 +440,14 @@ class SfgClass:
def members(
self, visibility: SfgVisibility | None = None
) -> Generator[SfgClassMember, None, None]:
yield from chain.from_iterable(
b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks)
)
if visibility is None:
yield from chain.from_iterable(
b.members() for b in self._blocks
)
else:
yield from chain.from_iterable(
b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks)
)
def definitions(
self, visibility: SfgVisibility | None = None
......
......@@ -7,7 +7,14 @@ from functools import reduce
from .dispatcher import visitor
from ..exceptions import SfgException
from ..tree import SfgCallTreeNode
from ..source_components import SfgFunction, SfgClass, SfgConstructor, SfgMemberVariable
from ..source_components import (
SfgFunction,
SfgClass,
SfgConstructor,
SfgMemberVariable,
SfgInClassDefinition,
)
from ..context import SfgContext
if TYPE_CHECKING:
from ..source_components import SfgHeaderInclude
......@@ -18,10 +25,23 @@ class CollectIncludes:
def visit(self, obj: object) -> set[SfgHeaderInclude]:
raise SfgException(f"Can't collect includes from object of type {type(obj)}")
@visit.case(SfgContext)
def context(self, ctx: SfgContext) -> set[SfgHeaderInclude]:
includes = set()
for func in ctx.functions():
includes |= self.visit(func)
for cls in ctx.classes():
includes |= self.visit(cls)
return includes
@visit.case(SfgCallTreeNode)
def tree_node(self, node: SfgCallTreeNode) -> set[SfgHeaderInclude]:
return reduce(
lambda accu, child: accu | self.visit(child), node.children, node.required_includes
lambda accu, child: accu | self.visit(child),
node.children,
node.required_includes,
)
@visit.case(SfgFunction)
......@@ -43,3 +63,7 @@ class CollectIncludes:
@visit.case(SfgMemberVariable)
def sfg_member_var(self, var: SfgMemberVariable) -> set[SfgHeaderInclude]:
return var.required_includes
@visit.case(SfgInClassDefinition)
def sfg_cls_def(self, _: SfgInClassDefinition) -> set[SfgHeaderInclude]:
return set()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment