From 5c595075fb2c25646444eed1ca2a42e628c685e2 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 17 Oct 2024 17:29:47 +0200
Subject: [PATCH] More frontend updates

 - Add `Ref` type
 - Allow multi-arg `init` in constructor builder
 - Change `CustomGenerator` to take a composer instead of a context.
 - Allow a class to have multiple methods with the same name.
---
 src/pystencilssfg/composer/basic_composer.py  | 23 +++++++++-------
 src/pystencilssfg/composer/class_composer.py  | 13 ++++------
 src/pystencilssfg/composer/custom.py          |  8 ++++--
 src/pystencilssfg/ir/source_components.py     | 14 +++-------
 src/pystencilssfg/lang/__init__.py            |  3 +++
 src/pystencilssfg/lang/types.py               | 26 +++++++++++++++++++
 tests/generator_scripts/expected/Structural.h |  2 +-
 7 files changed, 59 insertions(+), 30 deletions(-)
 create mode 100644 src/pystencilssfg/lang/types.py

diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py
index 0938489..35da6c5 100644
--- a/src/pystencilssfg/composer/basic_composer.py
+++ b/src/pystencilssfg/composer/basic_composer.py
@@ -65,15 +65,15 @@ class SfgNodeBuilder(ABC):
         pass
 
 
-_ExprLike = (str, AugExpr, TypedSymbol)
-ExprLike: TypeAlias = str | AugExpr | TypedSymbol
+_ExprLike = (str, AugExpr, SfgVar, TypedSymbol)
+ExprLike: TypeAlias = str | AugExpr | SfgVar | TypedSymbol
 """Things that may act as a C++ expression.
 
 Expressions need not necesserily have a known data type.
 """
 
-_VarLike = (TypedSymbol, AugExpr)
-VarLike: TypeAlias = TypedSymbol | AugExpr
+_VarLike = (TypedSymbol, SfgVar, AugExpr)
+VarLike: TypeAlias = TypedSymbol | SfgVar | AugExpr
 """Things that may act as a variable.
 
 Variables must always define their name *and* data type.
@@ -113,7 +113,7 @@ class SfgBasicComposer(SfgIComposer):
 
     def define(self, *definitions: str):
         """Add custom definitions to the generated header file.
-        
+
         Each string passed to this method will be printed out directly into the generated header file.
 
         :Example:
@@ -138,7 +138,7 @@ class SfgBasicComposer(SfgIComposer):
 
     def namespace(self, namespace: str):
         """Set the inner code namespace. Throws an exception if a namespace was already set.
-        
+
         :Example:
 
             After adding the following to your generator script:
@@ -150,14 +150,15 @@ class SfgBasicComposer(SfgIComposer):
             .. code-block:: C++
 
                 namespace codegen_is_awesome {
-                    /* all generated code */    
+                    /* all generated code */
                 }
         """
         self._ctx.set_namespace(namespace)
 
     def generate(self, generator: CustomGenerator):
         """Invoke a custom code generator with the underlying context."""
-        generator.generate(self._ctx)
+        from .composer import SfgComposer
+        generator.generate(SfgComposer(self))
 
     @property
     def kernels(self) -> SfgKernelNamespace:
@@ -254,7 +255,7 @@ class SfgBasicComposer(SfgIComposer):
         if self._ctx.get_function(name) is not None:
             raise ValueError(f"Function {name} already exists.")
 
-        def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder):
+        def sequencer(*args: SequencerArg):
             tree = make_sequence(*args)
             func = SfgFunction(name, tree)
             self._ctx.add_function(func)
@@ -663,6 +664,8 @@ def struct_from_numpy_dtype(
 
 def _asvar(var: VarLike) -> SfgVar:
     match var:
+        case SfgVar():
+            return var
         case AugExpr():
             return var.as_variable()
         case TypedSymbol():
@@ -683,6 +686,8 @@ def _depends(expr: ExprLike) -> set[SfgVar]:
     match expr:
         case None | str():
             return set()
+        case SfgVar():
+            return {expr}
         case TypedSymbol():
             return {_asvar(expr)}
         case AugExpr():
diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py
index 9b8fcfb..ed081dc 100644
--- a/src/pystencilssfg/composer/class_composer.py
+++ b/src/pystencilssfg/composer/class_composer.py
@@ -1,11 +1,8 @@
 from __future__ import annotations
 from typing import Sequence
 
-from pystencils import TypedSymbol
 from pystencils.types import PsCustomType, UserTypeSpec
 
-from ..lang import AugExpr
-from ..ir import SfgCallTreeNode
 from ..ir.source_components import (
     SfgClass,
     SfgClassMember,
@@ -21,11 +18,11 @@ from ..exceptions import SfgException
 
 from .mixin import SfgComposerMixIn
 from .basic_composer import (
-    SfgNodeBuilder,
     make_sequence,
     _VarLike,
     VarLike,
     ExprLike,
+    SequencerArg,
     _asvar,
 )
 
@@ -79,8 +76,8 @@ class SfgClassComposer(SfgComposerMixIn):
         def init(self, var: VarLike):
             """Add an initialization expression to the constructor's initializer list."""
 
-            def init_sequencer(expr: ExprLike):
-                expr = str(expr)
+            def init_sequencer(*args: ExprLike):
+                expr = ", ".join(str(arg) for arg in args)
                 initializer = f"{_asvar(var)}{{ {expr} }}"
                 self._initializers.append(initializer)
                 return self
@@ -159,7 +156,7 @@ class SfgClassComposer(SfgComposerMixIn):
             const: Whether or not the method is const-qualified.
         """
 
-        def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder):
+        def sequencer(*args: SequencerArg):
             tree = make_sequence(*args)
             return SfgMethod(
                 name,
@@ -221,7 +218,7 @@ class SfgClassComposer(SfgComposerMixIn):
         arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str,
     ) -> SfgClassMember:
         match arg:
-            case AugExpr() | TypedSymbol():
+            case _ if isinstance(arg, _VarLike):
                 var = _asvar(arg)
                 return SfgMemberVariable(var.name, var.dtype)
             case str():
diff --git a/src/pystencilssfg/composer/custom.py b/src/pystencilssfg/composer/custom.py
index fa53d6a..7df364c 100644
--- a/src/pystencilssfg/composer/custom.py
+++ b/src/pystencilssfg/composer/custom.py
@@ -1,5 +1,9 @@
+from __future__ import annotations
 from abc import ABC, abstractmethod
-from ..context import SfgContext
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .composer import SfgComposer
 
 
 class CustomGenerator(ABC):
@@ -7,4 +11,4 @@ class CustomGenerator(ABC):
     `SfgComposer.generate`."""
 
     @abstractmethod
-    def generate(self, ctx: SfgContext) -> None: ...
+    def generate(self, sfg: SfgComposer) -> None: ...
diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py
index 859f926..c8a72df 100644
--- a/src/pystencilssfg/ir/source_components.py
+++ b/src/pystencilssfg/ir/source_components.py
@@ -520,7 +520,7 @@ class SfgClass:
 
         self._definitions: list[SfgInClassDefinition] = []
         self._constructors: list[SfgConstructor] = []
-        self._methods: dict[str, SfgMethod] = dict()
+        self._methods: list[SfgMethod] = []
         self._member_vars: dict[str, SfgMemberVariable] = dict()
 
     @property
@@ -599,10 +599,10 @@ class SfgClass:
     ) -> Generator[SfgMethod, None, None]:
         if visibility is not None:
             yield from filter(
-                lambda m: m.visibility == visibility, self._methods.values()
+                lambda m: m.visibility == visibility, self._methods
             )
         else:
-            yield from self._methods.values()
+            yield from self._methods
 
     # PRIVATE
 
@@ -624,16 +624,10 @@ class SfgClass:
         self._definitions.append(definition)
 
     def _add_constructor(self, constr: SfgConstructor):
-        #   TODO: Check for signature conflicts?
         self._constructors.append(constr)
 
     def _add_method(self, method: SfgMethod):
-        if method.name in self._methods:
-            raise SfgException(
-                f"Duplicate method name {method.name} in class {self._class_name}"
-            )
-
-        self._methods[method.name] = method
+        self._methods.append(method)
 
     def _add_member_variable(self, variable: SfgMemberVariable):
         if variable.name in self._member_vars:
diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py
index 543d309..99661af 100644
--- a/src/pystencilssfg/lang/__init__.py
+++ b/src/pystencilssfg/lang/__init__.py
@@ -6,10 +6,13 @@ from .expressions import (
     SrcVector,
 )
 
+from .types import Ref
+
 __all__ = [
     "DependentExpression",
     "AugExpr",
     "IFieldExtraction",
     "SrcField",
     "SrcVector",
+    "Ref",
 ]
diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py
new file mode 100644
index 0000000..6f23160
--- /dev/null
+++ b/src/pystencilssfg/lang/types.py
@@ -0,0 +1,26 @@
+from typing import Any
+from pystencils.types import PsType
+
+
+class Ref(PsType):
+    """C++ reference type."""
+
+    __match_args__ = "base_type"
+
+    def __init__(self, base_type: PsType, const: bool = False):
+        super().__init__(False)
+        self._base_type = base_type
+
+    def __args__(self) -> tuple[Any, ...]:
+        return (self.base_type,)
+
+    @property
+    def base_type(self) -> PsType:
+        return self._base_type
+
+    def c_string(self) -> str:
+        base_str = self.base_type.c_string()
+        return base_str + "&"
+
+    def __repr__(self) -> str:
+        return f"Ref({repr(self.base_type)})"
diff --git a/tests/generator_scripts/expected/Structural.h b/tests/generator_scripts/expected/Structural.h
index 45ffcf6..0eb1e25 100644
--- a/tests/generator_scripts/expected/Structural.h
+++ b/tests/generator_scripts/expected/Structural.h
@@ -15,4 +15,4 @@ namespace awesome {
 #define PI 3.1415
 using namespace std;
 
-}
+} // namespace awesome
-- 
GitLab