From ed11c419a61e6c46145e99eaeeb14293cc184eda Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 13 Dec 2023 09:59:51 +0100 Subject: [PATCH] updated composer and fixed a few bugs --- integration/test_class_composer.py | 12 +-- src/pystencilssfg/composer.py | 137 +++++++++++++------------ src/pystencilssfg/source_components.py | 10 +- 3 files changed, 84 insertions(+), 75 deletions(-) diff --git a/integration/test_class_composer.py b/integration/test_class_composer.py index 4bb9860..07039b7 100644 --- a/integration/test_class_composer.py +++ b/integration/test_class_composer.py @@ -35,6 +35,12 @@ with SourceFileGenerator(sfg_config) as ctx: c.klass("MyClass", bases=("MyBaseClass",))( # class body sequencer + c.constructor(SrcObject("a", "int")) + .init("a_(a)") + .body( + 'cout << "Hi!" << endl;' + ), + c.private( c.var("a_", "int"), @@ -43,12 +49,6 @@ with SourceFileGenerator(sfg_config) as ctx: ) ), - c.constructor(SrcObject("a", "int")) - .init("a_(a)") - .body( - 'cout << "Hi!" << endl;' - ), - c.public( "using xtype = uint8_t;" ) diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py index e12f1d9..c46ba07 100644 --- a/src/pystencilssfg/composer.py +++ b/src/pystencilssfg/composer.py @@ -2,7 +2,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Sequence from abc import ABC, abstractmethod import numpy as np -from functools import partial from pystencils import Field from pystencils.astnodes import KernelFunction @@ -31,6 +30,7 @@ from .source_components import ( SfgMemberVariable, SfgClassKeyword, SfgVisibility, + SfgVisibilityBlock, ) from .source_concepts import SrcObject, SrcField, TypedSymbolOrObject, SrcVector from .types import cpp_typename, SrcType @@ -338,66 +338,47 @@ class SfgClassComposer: def __init__(self, ctx: SfgContext): self._ctx = ctx - class PartialMember: - def __init__(self, member_type: type[SfgClassMember], *args, **kwargs): - assert issubclass(member_type, SfgClassMember) - - self._type = member_type - self._partial = partial(member_type, *args, **kwargs) - - @property - def member_type(self): - return self._type - - def resolve(self, cls: SfgClass, visibility: SfgVisibility) -> SfgClassMember: - return self._partial(cls=cls, visibility=visibility) - class VisibilityContext: def __init__(self, visibility: SfgVisibility): - self._vis = visibility - self._partial_members: list[SfgClassComposer.PartialMember] = [] + self._vis_block = SfgVisibilityBlock(visibility) def members(self): - yield from self._partial_members - - def __call__(self, *args: SfgClassComposer.PartialMember | SrcObject | str): + yield from self._vis_block.members() + + def __call__( + self, + *args: SfgClassMember + | SfgClassComposer.ConstructorBuilder + | SrcObject + | str, + ): for arg in args: - if isinstance(arg, SrcObject): - self._partial_members.append( - SfgClassComposer.PartialMember( - SfgMemberVariable, name=arg.name, dtype=arg.dtype - ) - ) - elif isinstance(arg, str): - self._partial_members.append( - SfgClassComposer.PartialMember(SfgInClassDefinition, text=arg) - ) - else: - self._partial_members.append(arg) + self._vis_block.append_member(SfgClassComposer._resolve_member(arg)) return self - def resolve(self, cls: SfgClass) -> list[SfgClassMember]: - return [ - part.resolve(cls=cls, visibility=self._vis) - for part in self._partial_members - ] + def resolve(self, cls: SfgClass) -> None: + cls.append_visibility_block(self._vis_block) class ConstructorBuilder: def __init__(self, *params: SrcObject): self._params = params self._initializers: list[str] = [] + self._body = "" def init(self, initializer: str) -> SfgClassComposer.ConstructorBuilder: self._initializers.append(initializer) return self def body(self, body: str): - return SfgClassComposer.PartialMember( - SfgConstructor, + self._body = body + return self + + def resolve(self) -> SfgConstructor: + return SfgConstructor( parameters=self._params, initializers=self._initializers, - body=body, + body=self._body, ) def klass(self, class_name: str, bases: Sequence[str] = ()): @@ -410,12 +391,16 @@ class SfgClassComposer: def public(self) -> SfgClassComposer.VisibilityContext: return SfgClassComposer.VisibilityContext(SfgVisibility.PUBLIC) + @property + def protected(self) -> SfgClassComposer.VisibilityContext: + return SfgClassComposer.VisibilityContext(SfgVisibility.PROTECTED) + @property def private(self) -> SfgClassComposer.VisibilityContext: return SfgClassComposer.VisibilityContext(SfgVisibility.PRIVATE) def var(self, name: str, dtype: SrcType): - return SfgClassComposer.PartialMember(SfgMemberVariable, name=name, dtype=dtype) + return SfgMemberVariable(name, dtype) def constructor(self, *params): return SfgClassComposer.ConstructorBuilder(*params) @@ -429,13 +414,8 @@ class SfgClassComposer: ): def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): tree = make_sequence(*args) - return SfgClassComposer.PartialMember( - SfgMethod, - name=name, - tree=tree, - return_type=returns, - inline=inline, - const=const, + return SfgMethod( + name, tree, return_type=returns, inline=inline, const=const ) return sequencer @@ -449,22 +429,53 @@ class SfgClassComposer: cls = SfgClass(class_name, class_keyword=keyword, bases=bases) self._ctx.add_class(cls) - def sequencer(*args): - default_context = SfgClassComposer.VisibilityContext(SfgVisibility.DEFAULT) + def sequencer( + *args: SfgClassComposer.VisibilityContext + | SfgClassMember + | SfgClassComposer.ConstructorBuilder + | SrcObject + | str, + ): + default_ended = False + for arg in args: if isinstance(arg, SfgClassComposer.VisibilityContext): - for member in arg.resolve(cls): - cls.add_member(member) - elif isinstance(arg, (SfgClassComposer.PartialMember, SrcObject, str)): - default_context(arg) + default_ended = True + arg.resolve(cls) + elif isinstance( + arg, + ( + SfgClassMember, + SfgClassComposer.ConstructorBuilder, + SrcObject, + str, + ), + ): + if default_ended: + raise SfgException( + "Composer Syntax Error: " + "Cannot add members with default visibility after a visibility block." + ) + else: + cls.default.append_member(self._resolve_member(arg)) else: raise SfgException(f"{arg} is not a valid class member.") - for member in default_context.resolve(cls): - cls.add_member(member) - return sequencer + @staticmethod + def _resolve_member( + arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | SrcObject | str, + ): + if isinstance(arg, SrcObject): + return SfgMemberVariable(arg.name, arg.dtype) + elif isinstance(arg, str): + return SfgInClassDefinition(arg) + elif isinstance(arg, SfgClassComposer.ConstructorBuilder): + return arg.resolve() + else: + return arg + def struct_from_numpy_dtype( struct_name: str, dtype: np.dtype, add_constructor: bool = True @@ -481,22 +492,16 @@ def struct_from_numpy_dtype( for member_name, type_info in fields.items(): member_type = SrcType(cpp_typename(type_info[0])) - member = SfgMemberVariable( - member_name, member_type, cls, visibility=SfgVisibility.DEFAULT - ) + member = SfgMemberVariable(member_name, member_type) arg = SrcObject(f"{member_name}_", member_type) - cls._add_member_variable(member) + cls.default.append_member(member) constr_params.append(arg) constr_inits.append(f"{member}({arg})") if add_constructor: - cls._add_constructor( - SfgConstructor( - cls, constr_params, constr_inits, visibility=SfgVisibility.DEFAULT - ) - ) + cls.default.append_member(SfgConstructor(constr_params, constr_inits)) return cls diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index b398c3d..e753b37 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -233,7 +233,7 @@ class SfgClassKeyword(Enum): class SfgClassMember(ABC): - def __init__(self): + def __init__(self) -> None: self._cls: SfgClass | None = None self._visibility: SfgVisibility | None = None @@ -372,8 +372,8 @@ class SfgClass: ### Adding members to classes Members are never added directly to a class. Instead, they are added to - a SfgVisibilityBlock which defines their syntactic position and visibility modifier - in the code. + an [SfgVisibilityBlock][pystencilssfg.source_components.SfgVisibilityBlock] + which defines their syntactic position and visibility modifier in the code. At the top of every class, there is a default visibility block accessible through the `default` property. To add members with custom visibility, create a new SfgVisibilityBlock, @@ -425,6 +425,10 @@ class SfgClass: return self._default_block def append_visibility_block(self, block: SfgVisibilityBlock): + if block.visibility == SfgVisibility.DEFAULT: + raise SfgException( + "Can't add another block with DEFAULT visibility to a class. Use `.default` instead.") + block._bind(self) for m in block.members(): self._add_member(m, block.visibility) -- GitLab