Skip to content
Snippets Groups Projects
Commit af0e3408 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix header collection

parent cb4a449d
No related branches found
No related tags found
No related merge requests found
Pipeline #58337 passed
from typing import Generator, Sequence from typing import Generator, Sequence
from .configuration import SfgCodeStyle from .configuration import SfgCodeStyle
from .visitors import CollectIncludes
from .source_components import ( from .source_components import (
SfgHeaderInclude, SfgHeaderInclude,
SfgKernelNamespace, SfgKernelNamespace,
...@@ -156,9 +155,6 @@ class SfgContext: ...@@ -156,9 +155,6 @@ class SfgContext:
self._functions[func.name] = func self._functions[func.name] = func
self._declarations_ordered.append(func) self._declarations_ordered.append(func)
for incl in CollectIncludes().visit(func):
self.add_include(incl)
# ---------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------
# Classes # Classes
# ---------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------
...@@ -176,9 +172,6 @@ class SfgContext: ...@@ -176,9 +172,6 @@ class SfgContext:
self._classes[cls.class_name] = cls self._classes[cls.class_name] = cls
self._declarations_ordered.append(cls) self._declarations_ordered.append(cls)
for incl in CollectIncludes().visit(cls):
self.add_include(incl)
# ---------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------
# Declarations in order of addition # Declarations in order of addition
# ---------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------
......
...@@ -2,6 +2,7 @@ from os import path, makedirs ...@@ -2,6 +2,7 @@ from os import path, makedirs
from ..configuration import SfgOutputSpec from ..configuration import SfgOutputSpec
from ..context import SfgContext from ..context import SfgContext
from .prepare import prepare_context
from .printers import SfgHeaderPrinter, SfgImplPrinter from .printers import SfgHeaderPrinter, SfgImplPrinter
from .clang_format import invoke_clang_format from .clang_format import invoke_clang_format
...@@ -20,10 +21,12 @@ class HeaderSourcePairEmitter: ...@@ -20,10 +21,12 @@ class HeaderSourcePairEmitter:
def output_files(self) -> tuple[str, str]: def output_files(self) -> tuple[str, str]:
return ( return (
path.join(self._output_directory, self._header_filename), path.join(self._output_directory, self._header_filename),
path.join(self._output_directory, self._impl_filename) path.join(self._output_directory, self._impl_filename),
) )
def write_files(self, ctx: SfgContext): def write_files(self, ctx: SfgContext):
ctx = prepare_context(ctx)
header_printer = SfgHeaderPrinter(ctx, self._ospec) header_printer = SfgHeaderPrinter(ctx, self._ospec)
impl_printer = SfgImplPrinter(ctx, self._ospec) impl_printer = SfgImplPrinter(ctx, self._ospec)
...@@ -35,8 +38,8 @@ class HeaderSourcePairEmitter: ...@@ -35,8 +38,8 @@ class HeaderSourcePairEmitter:
makedirs(self._output_directory, exist_ok=True) makedirs(self._output_directory, exist_ok=True)
with open(self._ospec.get_header_filepath(), 'w') as headerfile: with open(self._ospec.get_header_filepath(), "w") as headerfile:
headerfile.write(header) headerfile.write(header)
with open(self._ospec.get_impl_filepath(), 'w') as cppfile: with open(self._ospec.get_impl_filepath(), "w") as cppfile:
cppfile.write(impl) cppfile.write(impl)
from ..context import SfgContext
from ..visitors import CollectIncludes
def prepare_context(ctx: SfgContext):
"""Prepares a populated context for printing. Make sure to run this function on the
[SfgContext][pystencilssfg.SfgContext] before passing it to a printer.
Steps:
- Collection of includes: All defined functions and classes are traversed to collect all required
header includes
"""
# Collect all includes
required_includes = CollectIncludes().visit(ctx)
for incl in required_includes:
ctx.add_include(incl)
return ctx
...@@ -440,9 +440,14 @@ class SfgClass: ...@@ -440,9 +440,14 @@ class SfgClass:
def members( def members(
self, visibility: SfgVisibility | None = None self, visibility: SfgVisibility | None = None
) -> Generator[SfgClassMember, None, None]: ) -> Generator[SfgClassMember, None, None]:
yield from chain.from_iterable( if visibility is None:
b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks) yield from chain.from_iterable(
) b.members() for b in self._blocks
)
else:
yield from chain.from_iterable(
b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks)
)
def definitions( def definitions(
self, visibility: SfgVisibility | None = None self, visibility: SfgVisibility | None = None
......
...@@ -7,7 +7,14 @@ from functools import reduce ...@@ -7,7 +7,14 @@ from functools import reduce
from .dispatcher import visitor from .dispatcher import visitor
from ..exceptions import SfgException from ..exceptions import SfgException
from ..tree import SfgCallTreeNode from ..tree import SfgCallTreeNode
from ..source_components import SfgFunction, SfgClass, SfgConstructor, SfgMemberVariable from ..source_components import (
SfgFunction,
SfgClass,
SfgConstructor,
SfgMemberVariable,
SfgInClassDefinition,
)
from ..context import SfgContext
if TYPE_CHECKING: if TYPE_CHECKING:
from ..source_components import SfgHeaderInclude from ..source_components import SfgHeaderInclude
...@@ -18,10 +25,23 @@ class CollectIncludes: ...@@ -18,10 +25,23 @@ class CollectIncludes:
def visit(self, obj: object) -> set[SfgHeaderInclude]: def visit(self, obj: object) -> set[SfgHeaderInclude]:
raise SfgException(f"Can't collect includes from object of type {type(obj)}") raise SfgException(f"Can't collect includes from object of type {type(obj)}")
@visit.case(SfgContext)
def context(self, ctx: SfgContext) -> set[SfgHeaderInclude]:
includes = set()
for func in ctx.functions():
includes |= self.visit(func)
for cls in ctx.classes():
includes |= self.visit(cls)
return includes
@visit.case(SfgCallTreeNode) @visit.case(SfgCallTreeNode)
def tree_node(self, node: SfgCallTreeNode) -> set[SfgHeaderInclude]: def tree_node(self, node: SfgCallTreeNode) -> set[SfgHeaderInclude]:
return reduce( return reduce(
lambda accu, child: accu | self.visit(child), node.children, node.required_includes lambda accu, child: accu | self.visit(child),
node.children,
node.required_includes,
) )
@visit.case(SfgFunction) @visit.case(SfgFunction)
...@@ -43,3 +63,7 @@ class CollectIncludes: ...@@ -43,3 +63,7 @@ class CollectIncludes:
@visit.case(SfgMemberVariable) @visit.case(SfgMemberVariable)
def sfg_member_var(self, var: SfgMemberVariable) -> set[SfgHeaderInclude]: def sfg_member_var(self, var: SfgMemberVariable) -> set[SfgHeaderInclude]:
return var.required_includes return var.required_includes
@visit.case(SfgInClassDefinition)
def sfg_cls_def(self, _: SfgInClassDefinition) -> set[SfgHeaderInclude]:
return set()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment