From d4c41740770e15ada8ba81144a8b7a941a8eeaff Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 3 Dec 2023 19:04:48 +0100
Subject: [PATCH] started on header printer

---
 integration/test_classes.py                  |  3 +
 src/pystencilssfg/printing/header_printer.py | 81 ++++++++++++++++++++
 src/pystencilssfg/source_components.py       | 35 ++++++---
 3 files changed, 109 insertions(+), 10 deletions(-)
 create mode 100644 src/pystencilssfg/printing/header_printer.py

diff --git a/integration/test_classes.py b/integration/test_classes.py
index 40d7098..192807b 100644
--- a/integration/test_classes.py
+++ b/integration/test_classes.py
@@ -29,18 +29,21 @@ with SourceFileGenerator(sfg_config) as sfg:
     cls.add_method(SfgMethod(
         "callKernel",
         sfg.call(khandle),
+        cls,
         visibility=SfgVisibility.PUBLIC
     ))
 
     cls.add_member_variable(
         SfgMemberVariable(
             "stuff", "std::vector< int >",
+            cls,
             SfgVisibility.PRIVATE
         )
     )
 
     cls.add_constructor(
         SfgConstructor(
+            cls,
             [SrcObject("std::vector< int > &", "stuff")],
             ["stuff_(stuff)"],
             visibility=SfgVisibility.PUBLIC
diff --git a/src/pystencilssfg/printing/header_printer.py b/src/pystencilssfg/printing/header_printer.py
new file mode 100644
index 0000000..e49ba9c
--- /dev/null
+++ b/src/pystencilssfg/printing/header_printer.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from textwrap import indent
+from itertools import chain, repeat
+
+from ..context import SfgContext
+from ..configuration import SfgOutputSpec
+from ..visitors import visitor
+from ..exceptions import SfgException
+
+from ..source_components import (
+    SfgEmptyLines, SfgHeaderInclude
+)
+
+
+def interleave(*iters):
+    try:
+        for iter in iters:
+            yield next(iter)
+    except StopIteration:
+        pass
+
+
+class SfgHeaderPrinter:
+    def __init__(self, output_spec: SfgOutputSpec):
+        self._output_spec = output_spec
+
+    def code_string(self, ctx: SfgContext) -> str:
+        return self.visit(ctx)
+
+    @visitor
+    def visit(self, obj: object) -> str:
+        raise SfgException(f"Can't print object of type {type(obj)}")
+
+    @visit.case(SfgEmptyLines)
+    def emptylines(self, el: SfgEmptyLines) -> str:
+        return "\n" * el.lines
+
+    @visit.case(str)
+    def string(self, s: str) -> str:
+        return s
+
+    @visit.case(SfgHeaderInclude)
+    def include(self, incl: SfgHeaderInclude) -> str:
+        if incl.system_header:
+            return f"#include <{incl.file}>"
+        else:
+            return f'#include "{incl.file}"'
+
+    @visit.case(SfgContext)
+    def frame(self, ctx: SfgContext) -> str:
+        code = ""
+
+        if ctx.prelude_comment:
+            code += "/*\n" + indent(ctx.prelude_comment, "* ", predicate=lambda _: True) + "*/\n"
+
+        code += "\n#pragma once\n\n"
+
+        includes = filter(lambda incl: not incl.private, ctx.includes())
+        code += "\n".join(self.visit(incl) for incl in includes)
+        code += "\n"
+
+        fq_namespace = ctx.fully_qualified_namespace
+        if fq_namespace is not None:
+            code += f"namespace {fq_namespace} {{\n"
+
+        parts = interleave(
+            chain(
+                ctx.definitions(),
+                ctx.classes(),
+                ctx.functions()
+            ),
+            repeat(SfgEmptyLines(1))
+        )
+
+        code += "".join(self.visit(p) for p in parts)
+
+        if fq_namespace is not None:
+            code += f"}} \\ namespace {fq_namespace}\n"
+
+        return code
diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py
index d0bf5c2..b921e36 100644
--- a/src/pystencilssfg/source_components.py
+++ b/src/pystencilssfg/source_components.py
@@ -17,6 +17,15 @@ if TYPE_CHECKING:
     from .tree import SfgCallTreeNode
 
 
+class SfgEmptyLines:
+    def __init__(self, lines: int):
+        self._lines = lines
+
+    @property
+    def lines(self) -> int:
+        return self._lines
+
+
 class SfgHeaderInclude:
     def __init__(
         self, header_file: str, system_header: bool = False, private: bool = False
@@ -25,6 +34,10 @@ class SfgHeaderInclude:
         self._system_header = system_header
         self._private = private
 
+    @property
+    def file(self) -> str:
+        return self._header_file
+
     @property
     def system_header(self):
         return self._system_header
@@ -33,12 +46,6 @@ class SfgHeaderInclude:
     def private(self):
         return self._private
 
-    def get_code(self):
-        if self._system_header:
-            return f"#include <{self._header_file}>"
-        else:
-            return f'#include "{self._header_file}"'
-
     def __hash__(self) -> int:
         return hash((self._header_file, self._system_header, self._private))
 
@@ -219,9 +226,14 @@ class SfgClassKeyword(Enum):
 
 
 class SfgClassMember(ABC):
-    def __init__(self, visibility: SfgVisibility):
+    def __init__(self, cls: SfgClass, visibility: SfgVisibility):
+        self._cls = cls
         self._visibility = visibility
 
+    @property
+    def owning_class(self) -> SfgClass:
+        return self._cls
+
     @property
     def visibility(self) -> SfgVisibility:
         return self._visibility
@@ -232,10 +244,11 @@ class SfgMemberVariable(SrcObject, SfgClassMember):
         self,
         name: str,
         type: SrcType,
+        cls: SfgClass,
         visibility: SfgVisibility = SfgVisibility.PRIVATE,
     ):
         SrcObject.__init__(self, type, name)
-        SfgClassMember.__init__(self, visibility)
+        SfgClassMember.__init__(self, cls, visibility)
 
 
 class SfgMethod(SfgFunction, SfgClassMember):
@@ -243,21 +256,23 @@ class SfgMethod(SfgFunction, SfgClassMember):
         self,
         name: str,
         tree: SfgCallTreeNode,
+        cls: SfgClass,
         visibility: SfgVisibility = SfgVisibility.PUBLIC,
     ):
         SfgFunction.__init__(self, name, tree)
-        SfgClassMember.__init__(self, visibility)
+        SfgClassMember.__init__(self, cls, visibility)
 
 
 class SfgConstructor(SfgClassMember):
     def __init__(
         self,
+        cls: SfgClass,
         parameters: Sequence[SrcObject] = (),
         initializers: Sequence[str] = (),
         body: str = "",
         visibility: SfgVisibility = SfgVisibility.PUBLIC,
     ):
-        SfgClassMember.__init__(self, visibility)
+        SfgClassMember.__init__(self, cls, visibility)
         self._parameters = tuple(parameters)
         self._initializers = tuple(initializers)
         self._body = body
-- 
GitLab