Skip to content
Snippets Groups Projects
class_composer.py 10.68 KiB
from __future__ import annotations
from typing import Sequence
from itertools import takewhile, dropwhile
import numpy as np

from pystencils.types import create_type

from ..context import SfgContext, SfgCursor
from ..lang import (
    VarLike,
    ExprLike,
    asvar,
    SfgVar,
)

from ..ir import (
    SfgCallTreeNode,
    SfgClass,
    SfgConstructor,
    SfgMethod,
    SfgMemberVariable,
    SfgClassKeyword,
    SfgVisibility,
    SfgVisibilityBlock,
    SfgEntityDecl,
    SfgEntityDef,
    SfgClassBody,
)
from ..exceptions import SfgException

from .mixin import SfgComposerMixIn
from .basic_composer import (
    make_sequence,
    SequencerArg,
    SfgFunctionSequencerBase,
)


class SfgMethodSequencer(SfgFunctionSequencerBase):
    """Composer-syntax builder for a single member function of a class.

    Qualifiers are switched on through the fluent setters (``const``,
    ``static``, ``virtual``, ``override``), the body is supplied by calling
    the sequencer, and ``_resolve`` finally materializes an ``SfgMethod``
    inside the owning class.
    """

    def __init__(self, cursor: SfgCursor, name: str) -> None:
        super().__init__(cursor, name)

        #   Method qualifiers; every one starts disabled.
        self._const: bool = False
        self._static: bool = False
        self._virtual: bool = False
        self._override: bool = False

        #   Body of the method; only bound once the sequencer is called.
        self._tree: SfgCallTreeNode

    def _enable_qualifier(self, attr: str):
        #   Shared implementation of the fluent qualifier setters below.
        setattr(self, attr, True)
        return self

    def const(self):
        """Mark this method as ``const``."""
        return self._enable_qualifier("_const")

    def static(self):
        """Mark this method as ``static``."""
        return self._enable_qualifier("_static")

    def virtual(self):
        """Mark this method as ``virtual``."""
        return self._enable_qualifier("_virtual")

    def override(self):
        """Mark this method as ``override``."""
        return self._enable_qualifier("_override")

    def __call__(self, *args: SequencerArg):
        body = make_sequence(*args)
        self._tree = body
        return self

    def _resolve(self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock):
        """Create the ``SfgMethod``, register it with ``cls`` and place its code.

        Inline methods are defined directly inside ``vis_block``; otherwise only
        a declaration goes into the class body and the definition is written to
        the implementation file through ``ctx._cursor``.
        """
        qualifiers = {
            "inline": self._inline,
            "const": self._const,
            "static": self._static,
            "constexpr": self._constexpr,
            "virtual": self._virtual,
            "override": self._override,
        }
        method = SfgMethod(
            self._name,
            cls,
            self._tree,
            return_type=self._return_type,
            attributes=self._attributes,
            required_params=self._params,
            **qualifiers,
        )
        cls.add_member(method, vis_block.visibility)

        if not self._inline:
            #   Out-of-line: declare in the header, define in the impl file.
            vis_block.elements.append(SfgEntityDecl(method))
            ctx._cursor.write_impl(SfgEntityDef(method))
        else:
            vis_block.elements.append(SfgEntityDef(method))


class SfgClassComposer(SfgComposerMixIn):
    """Composer for classes and structs.


    This class cannot be instantiated on its own but must be mixed in with
    :class:`SfgBasicComposer`.
    Its interface is exposed by :class:`SfgComposer`.
    """

    class VisibilityBlockSequencer:
        """Represent a visibility block in the composer syntax.

        Returned by `private`, `public`, and `protected`.
        """

        def __init__(self, visibility: SfgVisibility):
            self._visibility = visibility
            self._args: tuple[
                SfgMethodSequencer
                | SfgClassComposer.ConstructorBuilder
                | VarLike
                | str,
                ...,
            ]

        def __call__(
            self,
            *args: (
                SfgMethodSequencer | SfgClassComposer.ConstructorBuilder | VarLike | str
            ),
        ):
            self._args = args
            return self

        def _resolve(self, ctx: SfgContext, cls: SfgClass) -> SfgVisibilityBlock:
            vis_block = SfgVisibilityBlock(self._visibility)
            for arg in self._args:
                match arg:
                    case SfgMethodSequencer() | SfgClassComposer.ConstructorBuilder():
                        arg._resolve(ctx, cls, vis_block)
                    case str():
                        vis_block.elements.append(arg)
                    case _:
                        var = asvar(arg)
                        member_var = SfgMemberVariable(var.name, var.dtype, cls)
                        cls.add_member(member_var, vis_block.visibility)
                        vis_block.elements.append(SfgEntityDef(member_var))
            return vis_block

    class ConstructorBuilder:
        """Composer syntax for constructor building.

        Returned by `constructor`.
        """

        def __init__(self, *params: VarLike):
            self._params = list(asvar(p) for p in params)
            self._initializers: list[tuple[SfgVar | str, tuple[ExprLike, ...]]] = []
            self._body: str | None = None

        def add_param(self, param: VarLike, at: int | None = None):
            if at is None:
                self._params.append(asvar(param))
            else:
                self._params.insert(at, asvar(param))

        @property
        def parameters(self) -> list[SfgVar]:
            return self._params

        def init(self, var: VarLike | str):
            """Add an initialization expression to the constructor's initializer list."""

            member = var if isinstance(var, str) else asvar(var)

            def init_sequencer(*args: ExprLike):
                self._initializers.append((member, args))
                return self

            return init_sequencer

        def body(self, body: str):
            """Define the constructor body"""
            if self._body is not None:
                raise SfgException("Multiple definitions of constructor body.")
            self._body = body
            return self

        def _resolve(
            self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock
        ):
            ctor = SfgConstructor(
                cls,
                parameters=self._params,
                initializers=self._initializers,
                body=self._body if self._body is not None else "",
            )

            cls.add_member(ctor, vis_block.visibility)
            vis_block.elements.append(SfgEntityDef(ctor))

    def klass(self, class_name: str, bases: Sequence[str] = ()):
        """Create a class and add it to the underlying context.

        Args:
            class_name: Name of the class
            bases: List of base classes
        """
        return self._class(class_name, SfgClassKeyword.CLASS, bases)

    def struct(self, class_name: str, bases: Sequence[str] = ()):
        """Create a struct and add it to the underlying context.

        Args:
            class_name: Name of the struct
            bases: List of base classes
        """
        return self._class(class_name, SfgClassKeyword.STRUCT, bases)

    def numpy_struct(self, name: str, dtype: np.dtype, add_constructor: bool = True):
        """Add a numpy structured data type as a C++ struct

        Returns:
            The created class object
        """
        return self._struct_from_numpy_dtype(name, dtype, add_constructor)

    @property
    def public(self) -> SfgClassComposer.VisibilityBlockSequencer:
        """Create a `public` visibility block in a class body"""
        return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PUBLIC)

    @property
    def protected(self) -> SfgClassComposer.VisibilityBlockSequencer:
        """Create a `protected` visibility block in a class or struct body"""
        return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PROTECTED)

    @property
    def private(self) -> SfgClassComposer.VisibilityBlockSequencer:
        """Create a `private` visibility block in a class or struct body"""
        return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PRIVATE)

    def constructor(self, *params: VarLike):
        """In a class or struct body or visibility block, add a constructor.

        Args:
            params: List of constructor parameters
        """
        return SfgClassComposer.ConstructorBuilder(*params)

    def method(self, name: str):
        """In a class or struct body or visibility block, add a method.
        The usage is similar to :any:`SfgBasicComposer.function`.

        Args:
            name: The method name
        """

        seq = SfgMethodSequencer(self._cursor, name)
        if self._ctx.impl_file is None:
            seq.inline()
        return seq

    #   INTERNALS

    def _class(self, class_name: str, keyword: SfgClassKeyword, bases: Sequence[str]):
        #   TODO: Return a `CppClass` instance representing the generated class

        if self._cursor.get_entity(class_name) is not None:
            raise ValueError(
                f"Another entity with name {class_name} already exists in the current namespace."
            )

        cls = SfgClass(
            class_name,
            self._cursor.current_namespace,
            class_keyword=keyword,
            bases=bases,
        )
        self._cursor.add_entity(cls)

        def sequencer(
            *args: (
                SfgClassComposer.VisibilityBlockSequencer
                | SfgMethodSequencer
                | SfgClassComposer.ConstructorBuilder
                | VarLike
                | str
            ),
        ):
            default_vis_sequencer = SfgClassComposer.VisibilityBlockSequencer(
                SfgVisibility.DEFAULT
            )

            def argfilter(arg):
                return not isinstance(arg, SfgClassComposer.VisibilityBlockSequencer)

            default_vis_args = takewhile(
                argfilter,
                args,
            )
            default_block = default_vis_sequencer(*default_vis_args)._resolve(self._ctx, cls)  # type: ignore
            vis_blocks: list[SfgVisibilityBlock] = []

            for arg in dropwhile(argfilter, args):
                if isinstance(arg, SfgClassComposer.VisibilityBlockSequencer):
                    vis_blocks.append(arg._resolve(self._ctx, cls))
                else:
                    raise SfgException(
                        "Composer Syntax Error: "
                        "Cannot add members with default visibility after a visibility block."
                    )

            self._cursor.write_header(SfgClassBody(cls, default_block, vis_blocks))

        return sequencer

    def _struct_from_numpy_dtype(
        self, struct_name: str, dtype: np.dtype, add_constructor: bool = True
    ):
        fields = dtype.fields
        if fields is None:
            raise SfgException(f"Numpy dtype {dtype} is not a structured type.")

        members: list[SfgClassComposer.ConstructorBuilder | SfgVar] = []
        if add_constructor:
            ctor = self.constructor()
            members.append(ctor)

        for member_name, type_info in fields.items():
            member_type = create_type(type_info[0])

            member = SfgVar(member_name, member_type)
            members.append(member)

            if add_constructor:
                arg = SfgVar(f"{member_name}_", member_type)
                ctor.add_param(arg)
                ctor.init(member)(arg)

        return self.struct(
            struct_name,
        )(*members)