From ed11c419a61e6c46145e99eaeeb14293cc184eda Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 13 Dec 2023 09:59:51 +0100
Subject: [PATCH] updated composer and fixed a few bugs

---
 integration/test_class_composer.py     |  12 +--
 src/pystencilssfg/composer.py          | 137 +++++++++++++------------
 src/pystencilssfg/source_components.py |  10 +-
 3 files changed, 84 insertions(+), 75 deletions(-)

diff --git a/integration/test_class_composer.py b/integration/test_class_composer.py
index 4bb9860..07039b7 100644
--- a/integration/test_class_composer.py
+++ b/integration/test_class_composer.py
@@ -35,6 +35,12 @@ with SourceFileGenerator(sfg_config) as ctx:
     c.klass("MyClass", bases=("MyBaseClass",))(
         # class body sequencer
 
+        c.constructor(SrcObject("a", "int"))
+        .init("a_(a)")
+        .body(
+            'cout << "Hi!" << endl;'
+        ),
+
         c.private(
             c.var("a_", "int"),
 
@@ -43,12 +49,6 @@ with SourceFileGenerator(sfg_config) as ctx:
             )
         ),
 
-        c.constructor(SrcObject("a", "int"))
-        .init("a_(a)")
-        .body(
-            'cout << "Hi!" << endl;'
-        ),
-
         c.public(
             "using xtype = uint8_t;"
         )
diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py
index e12f1d9..c46ba07 100644
--- a/src/pystencilssfg/composer.py
+++ b/src/pystencilssfg/composer.py
@@ -2,7 +2,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Sequence
 from abc import ABC, abstractmethod
 import numpy as np
-from functools import partial
 
 from pystencils import Field
 from pystencils.astnodes import KernelFunction
@@ -31,6 +30,7 @@ from .source_components import (
     SfgMemberVariable,
     SfgClassKeyword,
     SfgVisibility,
+    SfgVisibilityBlock,
 )
 from .source_concepts import SrcObject, SrcField, TypedSymbolOrObject, SrcVector
 from .types import cpp_typename, SrcType
@@ -338,66 +338,47 @@ class SfgClassComposer:
     def __init__(self, ctx: SfgContext):
         self._ctx = ctx
 
-    class PartialMember:
-        def __init__(self, member_type: type[SfgClassMember], *args, **kwargs):
-            assert issubclass(member_type, SfgClassMember)
-
-            self._type = member_type
-            self._partial = partial(member_type, *args, **kwargs)
-
-        @property
-        def member_type(self):
-            return self._type
-
-        def resolve(self, cls: SfgClass, visibility: SfgVisibility) -> SfgClassMember:
-            return self._partial(cls=cls, visibility=visibility)
-
     class VisibilityContext:
         def __init__(self, visibility: SfgVisibility):
-            self._vis = visibility
-            self._partial_members: list[SfgClassComposer.PartialMember] = []
+            self._vis_block = SfgVisibilityBlock(visibility)
 
         def members(self):
-            yield from self._partial_members
-
-        def __call__(self, *args: SfgClassComposer.PartialMember | SrcObject | str):
+            yield from self._vis_block.members()
+
+        def __call__(
+            self,
+            *args: SfgClassMember
+            | SfgClassComposer.ConstructorBuilder
+            | SrcObject
+            | str,
+        ):
             for arg in args:
-                if isinstance(arg, SrcObject):
-                    self._partial_members.append(
-                        SfgClassComposer.PartialMember(
-                            SfgMemberVariable, name=arg.name, dtype=arg.dtype
-                        )
-                    )
-                elif isinstance(arg, str):
-                    self._partial_members.append(
-                        SfgClassComposer.PartialMember(SfgInClassDefinition, text=arg)
-                    )
-                else:
-                    self._partial_members.append(arg)
+                self._vis_block.append_member(SfgClassComposer._resolve_member(arg))
 
             return self
 
-        def resolve(self, cls: SfgClass) -> list[SfgClassMember]:
-            return [
-                part.resolve(cls=cls, visibility=self._vis)
-                for part in self._partial_members
-            ]
+        def resolve(self, cls: SfgClass) -> None:
+            cls.append_visibility_block(self._vis_block)
 
     class ConstructorBuilder:
         def __init__(self, *params: SrcObject):
             self._params = params
             self._initializers: list[str] = []
+            self._body = ""
 
         def init(self, initializer: str) -> SfgClassComposer.ConstructorBuilder:
             self._initializers.append(initializer)
             return self
 
         def body(self, body: str):
-            return SfgClassComposer.PartialMember(
-                SfgConstructor,
+            self._body = body
+            return self
+
+        def resolve(self) -> SfgConstructor:
+            return SfgConstructor(
                 parameters=self._params,
                 initializers=self._initializers,
-                body=body,
+                body=self._body,
             )
 
     def klass(self, class_name: str, bases: Sequence[str] = ()):
@@ -410,12 +391,16 @@ class SfgClassComposer:
     def public(self) -> SfgClassComposer.VisibilityContext:
         return SfgClassComposer.VisibilityContext(SfgVisibility.PUBLIC)
 
+    @property
+    def protected(self) -> SfgClassComposer.VisibilityContext:
+        return SfgClassComposer.VisibilityContext(SfgVisibility.PROTECTED)
+
     @property
     def private(self) -> SfgClassComposer.VisibilityContext:
         return SfgClassComposer.VisibilityContext(SfgVisibility.PRIVATE)
 
     def var(self, name: str, dtype: SrcType):
-        return SfgClassComposer.PartialMember(SfgMemberVariable, name=name, dtype=dtype)
+        return SfgMemberVariable(name, dtype)
 
     def constructor(self, *params):
         return SfgClassComposer.ConstructorBuilder(*params)
@@ -429,13 +414,8 @@ class SfgClassComposer:
     ):
         def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder):
             tree = make_sequence(*args)
-            return SfgClassComposer.PartialMember(
-                SfgMethod,
-                name=name,
-                tree=tree,
-                return_type=returns,
-                inline=inline,
-                const=const,
+            return SfgMethod(
+                name, tree, return_type=returns, inline=inline, const=const
             )
 
         return sequencer
@@ -449,22 +429,53 @@ class SfgClassComposer:
         cls = SfgClass(class_name, class_keyword=keyword, bases=bases)
         self._ctx.add_class(cls)
 
-        def sequencer(*args):
-            default_context = SfgClassComposer.VisibilityContext(SfgVisibility.DEFAULT)
+        def sequencer(
+            *args: SfgClassComposer.VisibilityContext
+            | SfgClassMember
+            | SfgClassComposer.ConstructorBuilder
+            | SrcObject
+            | str,
+        ):
+            default_ended = False
+
             for arg in args:
                 if isinstance(arg, SfgClassComposer.VisibilityContext):
-                    for member in arg.resolve(cls):
-                        cls.add_member(member)
-                elif isinstance(arg, (SfgClassComposer.PartialMember, SrcObject, str)):
-                    default_context(arg)
+                    default_ended = True
+                    arg.resolve(cls)
+                elif isinstance(
+                    arg,
+                    (
+                        SfgClassMember,
+                        SfgClassComposer.ConstructorBuilder,
+                        SrcObject,
+                        str,
+                    ),
+                ):
+                    if default_ended:
+                        raise SfgException(
+                            "Composer Syntax Error: "
+                            "Cannot add members with default visibility after a visibility block."
+                        )
+                    else:
+                        cls.default.append_member(self._resolve_member(arg))
                 else:
                     raise SfgException(f"{arg} is not a valid class member.")
 
-            for member in default_context.resolve(cls):
-                cls.add_member(member)
-
         return sequencer
 
+    @staticmethod
+    def _resolve_member(
+        arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str,
+    ):
+        if isinstance(arg, SrcObject):
+            return SfgMemberVariable(arg.name, arg.dtype)
+        elif isinstance(arg, str):
+            return SfgInClassDefinition(arg)
+        elif isinstance(arg, SfgClassComposer.ConstructorBuilder):
+            return arg.resolve()
+        else:
+            return arg
+
 
 def struct_from_numpy_dtype(
     struct_name: str, dtype: np.dtype, add_constructor: bool = True
@@ -481,22 +492,16 @@ def struct_from_numpy_dtype(
     for member_name, type_info in fields.items():
         member_type = SrcType(cpp_typename(type_info[0]))
 
-        member = SfgMemberVariable(
-            member_name, member_type, cls, visibility=SfgVisibility.DEFAULT
-        )
+        member = SfgMemberVariable(member_name, member_type)
 
         arg = SrcObject(f"{member_name}_", member_type)
 
-        cls._add_member_variable(member)
+        cls.default.append_member(member)
 
         constr_params.append(arg)
         constr_inits.append(f"{member}({arg})")
 
     if add_constructor:
-        cls._add_constructor(
-            SfgConstructor(
-                cls, constr_params, constr_inits, visibility=SfgVisibility.DEFAULT
-            )
-        )
+        cls.default.append_member(SfgConstructor(constr_params, constr_inits))
 
     return cls
diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py
index b398c3d..e753b37 100644
--- a/src/pystencilssfg/source_components.py
+++ b/src/pystencilssfg/source_components.py
@@ -233,7 +233,7 @@ class SfgClassKeyword(Enum):
 
 
 class SfgClassMember(ABC):
-    def __init__(self):
+    def __init__(self) -> None:
         self._cls: SfgClass | None = None
         self._visibility: SfgVisibility | None = None
 
@@ -372,8 +372,8 @@ class SfgClass:
     ### Adding members to classes
 
     Members are never added directly to a class. Instead, they are added to
-    a SfgVisibilityBlock which defines their syntactic position and visibility modifier
-    in the code.
+    an [SfgVisibilityBlock][pystencilssfg.source_components.SfgVisibilityBlock]
+    which defines their syntactic position and visibility modifier in the code.
     At the top of every class, there is a default visibility block
     accessible through the `default` property.
     To add members with custom visibility, create a new SfgVisibilityBlock,
@@ -425,6 +425,10 @@ class SfgClass:
         return self._default_block
 
     def append_visibility_block(self, block: SfgVisibilityBlock):
+        if block.visibility == SfgVisibility.DEFAULT:
+            raise SfgException(
+                "Can't add another block with DEFAULT visibility to a class. Use `.default` instead.")
+
         block._bind(self)
         for m in block.members():
             self._add_member(m, block.visibility)
-- 
GitLab