diff --git a/integration/test_classes.py b/integration/test_classes.py index d298a6c775edcbe4886a24ae8ec41eb9a474cf13..7db94d2db5edf6d1a20450c9346a9ec92fd6fb7e 100644 --- a/integration/test_classes.py +++ b/integration/test_classes.py @@ -28,51 +28,41 @@ with SourceFileGenerator(sfg_config) as ctx: khandle = sfg.kernels.create(assignments) cls = SfgClass("MyClass") - cls.add_method(SfgMethod( + cls.default.append_member(SfgMethod( "callKernel", - sfg.call(khandle), - cls, - visibility=SfgVisibility.PUBLIC + sfg.call(khandle) )) - cls.add_method(SfgMethod( + cls.default.append_member(SfgMethod( "inlineConst", sfg.seq( "return -1.0;" ), - cls, - visibility=SfgVisibility.PUBLIC, return_type=SrcType("double"), inline=True, const=True )) - cls.add_method(SfgMethod( + cls.default.append_member(SfgMethod( "awesomeMethod", sfg.seq( "return 2.0f;" ), - cls, - visibility=SfgVisibility.PRIVATE, return_type=SrcType("float"), inline=False, const=True )) - cls.add_member_variable( + cls.default.append_member( SfgMemberVariable( - "stuff", "std::vector< int >", - cls, - SfgVisibility.PRIVATE + "stuff", "std::vector< int >" ) ) - cls.add_constructor( + cls.default.append_member( SfgConstructor( - cls, [SrcObject("stuff", "std::vector< int > &")], - ["stuff_(stuff)"], - visibility=SfgVisibility.PUBLIC + ["stuff_(stuff)"] ) ) diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py index 08183d0afc98310414ebae116fc0325093cb9037..e12f1d93e5a4f834307e698121008f694f009ed4 100644 --- a/src/pystencilssfg/composer.py +++ b/src/pystencilssfg/composer.py @@ -487,13 +487,13 @@ def struct_from_numpy_dtype( arg = SrcObject(f"{member_name}_", member_type) - cls.add_member_variable(member) + cls._add_member_variable(member) constr_params.append(arg) constr_inits.append(f"{member}({arg})") if add_constructor: - cls.add_constructor( + cls._add_constructor( SfgConstructor( cls, constr_params, constr_inits, visibility=SfgVisibility.DEFAULT ) diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index 69c19a8d74e567b1051b0724ec2a6f52e24237a7..ae504f4a3fd64a892788977cb5c62223101767c9 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -23,6 +23,7 @@ from ..source_components import ( SfgMemberVariable, SfgMethod, SfgVisibility, + SfgVisibilityBlock ) @@ -118,19 +119,24 @@ class SfgHeaderPrinter(SfgGeneralPrinter): code += f" : {','.join(cls.base_classes)}\n" code += "{\n" - for visibility in ( - SfgVisibility.DEFAULT, - SfgVisibility.PUBLIC, - SfgVisibility.PRIVATE, - ): - if visibility != SfgVisibility.DEFAULT: - code += f"\n{visibility}:\n" - for member in cls.members(visibility): - code += self._ctx.codestyle.indent(self.visit(member)) + "\n" + + for block in cls.visibility_blocks(): + code += self.visit(block) + "\n" + code += "};\n" return code + @visit.case(SfgVisibilityBlock) + def vis_block(self, block: SfgVisibilityBlock) -> str: + code = "" + if block.visibility != SfgVisibility.DEFAULT: + code += f"{block.visibility}:\n" + code += self._ctx.codestyle.indent( + "\n".join(self.visit(m) for m in block.members()) + ) + return code + @visit.case(SfgInClassDefinition) def sfg_inclassdef(self, definition: SfgInClassDefinition): return definition.text diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index a33c98d5434f4912a4335339318d2e684ab2b67c..b398c3dc049a2375418395a2d0d7709f16544829 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -4,6 +4,7 @@ from abc import ABC from enum import Enum, auto from typing import TYPE_CHECKING, Sequence, Generator from dataclasses import replace +from itertools import chain from pystencils import CreateKernelConfig, create_kernel from pystencils.astnodes import KernelFunction @@ -204,6 +205,7 @@ class SfgFunction: class SfgVisibility(Enum): DEFAULT = auto() PRIVATE = auto() + PROTECTED = auto() PUBLIC = auto() def __str__(self) -> str: @@ -212,6 +214,8 @@ class SfgVisibility(Enum): return "" case SfgVisibility.PRIVATE: return "private" + case SfgVisibility.PROTECTED: + return "protected" case SfgVisibility.PUBLIC: return "public" @@ -229,22 +233,70 @@ class SfgClassKeyword(Enum): class SfgClassMember(ABC): - def __init__(self, cls: SfgClass, visibility: SfgVisibility): - self._cls = cls - self._visibility = visibility + def __init__(self): + self._cls: SfgClass | None = None + self._visibility: SfgVisibility | None = None @property def owning_class(self) -> SfgClass: + if self._cls is None: + raise SfgException(f"{self} is not bound to a class.") return self._cls @property def visibility(self) -> SfgVisibility: + if self._visibility is None: + raise SfgException(f"{self} is not bound to a class and therefore has no visibility.") return self._visibility + @property + def is_bound(self) -> bool: + return self._cls is not None + + def _bind(self, cls: SfgClass, vis: SfgVisibility): + if self.is_bound: + raise SfgException( + f"Binding {self} to class {cls.class_name} failed: " + f"{self} was already bound to {self.owning_class.class_name}" + ) + self._cls = cls + self._vis = vis + + +class SfgVisibilityBlock: + def __init__(self, visibility: SfgVisibility) -> None: + self._vis = visibility + self._members: list[SfgClassMember] = [] + self._cls: SfgClass | None = None + + @property + def visibility(self) -> SfgVisibility: + return self._vis + + def append_member(self, member: SfgClassMember): + if self._cls is not None: + self._cls._add_member(member, self._vis) + self._members.append(member) + + def members(self) -> Generator[SfgClassMember, None, None]: + yield from self._members + + @property + def is_bound(self) -> bool: + return self._cls is not None + + def _bind(self, cls: SfgClass): + if self._cls is not None: + raise SfgException( + f"Binding visibility block to class {cls.class_name} failed: " + f"was already bound to {self._cls.class_name}" + ) + self._cls = cls + class SfgInClassDefinition(SfgClassMember): - def __init__(self, text: str, cls: SfgClass, visibility: SfgVisibility): - SfgClassMember.__init__(self, cls, visibility) + def __init__(self, text: str): + SfgClassMember.__init__(self) self._text = text @property @@ -259,12 +311,10 @@ class SfgMemberVariable(SrcObject, SfgClassMember): def __init__( self, name: str, - dtype: SrcType, - cls: SfgClass, - visibility: SfgVisibility = SfgVisibility.PRIVATE, + dtype: SrcType ): SrcObject.__init__(self, name, dtype) - SfgClassMember.__init__(self, cls, visibility) + SfgClassMember.__init__(self) class SfgMethod(SfgFunction, SfgClassMember): @@ -272,14 +322,12 @@ class SfgMethod(SfgFunction, SfgClassMember): self, name: str, tree: SfgCallTreeNode, - cls: SfgClass, - visibility: SfgVisibility = SfgVisibility.PUBLIC, return_type: SrcType = SrcType("void"), inline: bool = False, const: bool = False, ): SfgFunction.__init__(self, name, tree, return_type=return_type) - SfgClassMember.__init__(self, cls, visibility) + SfgClassMember.__init__(self) self._inline = inline self._const = const @@ -296,13 +344,11 @@ class SfgMethod(SfgFunction, SfgClassMember): class SfgConstructor(SfgClassMember): def __init__( self, - cls: SfgClass, parameters: Sequence[SrcObject] = (), initializers: Sequence[str] = (), - body: str = "", - visibility: SfgVisibility = SfgVisibility.PUBLIC, + body: str = "" ): - SfgClassMember.__init__(self, cls, visibility) + SfgClassMember.__init__(self) self._parameters = tuple(parameters) self._initializers = tuple(initializers) self._body = body @@ -321,6 +367,21 @@ class SfgConstructor(SfgClassMember): class SfgClass: + """Models a C++ class. + + ### Adding members to classes + + Members are never added directly to a class. Instead, they are added to + a SfgVisibilityBlock which defines their syntactic position and visibility modifier + in the code. + At the top of every class, there is a default visibility block + accessible through the `default` property. + To add members with custom visibility, create a new SfgVisibilityBlock, + add members to the block, and add the block using `append_visibility_block`. + + A more succinct interface for constructing classes is available through the + [SfgClassComposer][pystencilssfg.composer.SfgClassComposer]. + """ def __init__( self, class_name: str, @@ -334,6 +395,10 @@ class SfgClass: self._class_keyword = class_keyword self._bases_classes = tuple(bases) + self._default_block = SfgVisibilityBlock(SfgVisibility.DEFAULT) + self._default_block._bind(self) + self._blocks = [self._default_block] + self._definitions: list[SfgInClassDefinition] = [] self._constructors: list[SfgConstructor] = [] self._methods: dict[str, SfgMethod] = dict() @@ -355,25 +420,25 @@ class SfgClass: def class_keyword(self) -> SfgClassKeyword: return self._class_keyword + @property + def default(self) -> SfgVisibilityBlock: + return self._default_block + + def append_visibility_block(self, block: SfgVisibilityBlock): + block._bind(self) + for m in block.members(): + self._add_member(m, block.visibility) + self._blocks.append(block) + + def visibility_blocks(self) -> Generator[SfgVisibilityBlock, None, None]: + yield from self._blocks + def members( self, visibility: SfgVisibility | None = None ) -> Generator[SfgClassMember, None, None]: - yield from self.definitions(visibility) - yield from self.member_variables(visibility) - yield from self.constructors(visibility) - yield from self.methods(visibility) - - def add_member(self, member: SfgClassMember): - if isinstance(member, SfgConstructor): - self.add_constructor(member) - elif isinstance(member, SfgMemberVariable): - self.add_member_variable(member) - elif isinstance(member, SfgMethod): - self.add_method(member) - elif isinstance(member, SfgInClassDefinition): - self.add_definition(member) - else: - raise SfgException(f"{member} is not a valid class member.") + yield from chain.from_iterable( + b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks) + ) def definitions( self, visibility: SfgVisibility | None = None @@ -383,8 +448,15 @@ class SfgClass: else: yield from self._definitions - def add_definition(self, definition: SfgInClassDefinition): - self._definitions.append(definition) + def member_variables( + self, visibility: SfgVisibility | None = None + ) -> Generator[SfgMemberVariable, None, None]: + if visibility is not None: + yield from filter( + lambda m: m.visibility == visibility, self._member_vars.values() + ) + else: + yield from self._member_vars.values() def constructors( self, visibility: SfgVisibility | None = None @@ -394,10 +466,6 @@ class SfgClass: else: yield from self._constructors - def add_constructor(self, constr: SfgConstructor): - # TODO: Check for signature conflicts? - self._constructors.append(constr) - def methods( self, visibility: SfgVisibility | None = None ) -> Generator[SfgMethod, None, None]: @@ -408,7 +476,30 @@ class SfgClass: else: yield from self._methods.values() - def add_method(self, method: SfgMethod): + # PRIVATE + + def _add_member(self, member: SfgClassMember, vis: SfgVisibility): + if isinstance(member, SfgConstructor): + self._add_constructor(member) + elif isinstance(member, SfgMemberVariable): + self._add_member_variable(member) + elif isinstance(member, SfgMethod): + self._add_method(member) + elif isinstance(member, SfgInClassDefinition): + self._add_definition(member) + else: + raise SfgException(f"{member} is not a valid class member.") + + member._bind(self, vis) + + def _add_definition(self, definition: SfgInClassDefinition): + self._definitions.append(definition) + + def _add_constructor(self, constr: SfgConstructor): + # TODO: Check for signature conflicts? + self._constructors.append(constr) + + def _add_method(self, method: SfgMethod): if method.name in self._methods: raise SfgException( f"Duplicate method name {method.name} in class {self._class_name}" @@ -416,17 +507,7 @@ class SfgClass: self._methods[method.name] = method - def member_variables( - self, visibility: SfgVisibility | None = None - ) -> Generator[SfgMemberVariable, None, None]: - if visibility is not None: - yield from filter( - lambda m: m.visibility == visibility, self._member_vars.values() - ) - else: - yield from self._member_vars.values() - - def add_member_variable(self, variable: SfgMemberVariable): + def _add_member_variable(self, variable: SfgMemberVariable): if variable.name in self._member_vars: raise SfgException( f"Duplicate field name {variable.name} in class {self._class_name}"