diff --git a/integration/test_class_composer.py b/integration/test_class_composer.py index 2f6f48d6b8179b029d156822932848b01ef11c0d..4bb9860ba5a4bf05858f916f488788f6f652c643 100644 --- a/integration/test_class_composer.py +++ b/integration/test_class_composer.py @@ -50,6 +50,6 @@ with SourceFileGenerator(sfg_config) as ctx: ), c.public( - + "using xtype = uint8_t;" ) ) diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py index f4ef3bf33c0f035a25c475d1c0207f24f14ddbb1..08183d0afc98310414ebae116fc0325093cb9037 100644 --- a/src/pystencilssfg/composer.py +++ b/src/pystencilssfg/composer.py @@ -25,6 +25,7 @@ from .source_components import ( SfgKernelHandle, SfgClass, SfgClassMember, + SfgInClassDefinition, SfgConstructor, SfgMethod, SfgMemberVariable, @@ -359,21 +360,28 @@ class SfgClassComposer: def members(self): yield from self._partial_members - def __call__(self, *args: SfgClassComposer.PartialMember | SrcObject): + def __call__(self, *args: SfgClassComposer.PartialMember | SrcObject | str): for arg in args: if isinstance(arg, SrcObject): - self._partial_members.append(SfgClassComposer.PartialMember( - SfgMemberVariable, - name=arg.name, - dtype=arg.dtype - )) + self._partial_members.append( + SfgClassComposer.PartialMember( + SfgMemberVariable, name=arg.name, dtype=arg.dtype + ) + ) + elif isinstance(arg, str): + self._partial_members.append( + SfgClassComposer.PartialMember(SfgInClassDefinition, text=arg) + ) else: self._partial_members.append(arg) return self def resolve(self, cls: SfgClass) -> list[SfgClassMember]: - return [part.resolve(cls=cls, visibility=self._vis) for part in self._partial_members] + return [ + part.resolve(cls=cls, visibility=self._vis) + for part in self._partial_members + ] class ConstructorBuilder: def __init__(self, *params: SrcObject): @@ -389,7 +397,7 @@ class SfgClassComposer: SfgConstructor, parameters=self._params, initializers=self._initializers, - body=body + body=body, ) def klass(self, class_name: str, bases: Sequence[str] = ()): @@ -413,12 +421,12 @@ class SfgClassComposer: return SfgClassComposer.ConstructorBuilder(*params) def method( - self, - name: str, - returns: SrcType = SrcType("void"), - inline: bool = False, - const: bool = False): - + self, + name: str, + returns: SrcType = SrcType("void"), + inline: bool = False, + const: bool = False, + ): def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): tree = make_sequence(*args) return SfgClassComposer.PartialMember( @@ -427,7 +435,7 @@ class SfgClassComposer: tree=tree, return_type=returns, inline=inline, - const=const + const=const, ) return sequencer @@ -447,7 +455,7 @@ class SfgClassComposer: if isinstance(arg, SfgClassComposer.VisibilityContext): for member in arg.resolve(cls): cls.add_member(member) - elif isinstance(arg, (SfgClassComposer.PartialMember, SrcObject)): + elif isinstance(arg, (SfgClassComposer.PartialMember, SrcObject, str)): default_context(arg) else: raise SfgException(f"{arg} is not a valid class member.") diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index 6a771568d4af44e63a992c6517ca145c93182e95..69c19a8d74e567b1051b0724ec2a6f52e24237a7 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -18,6 +18,7 @@ from ..source_components import ( SfgKernelNamespace, SfgFunction, SfgClass, + SfgInClassDefinition, SfgConstructor, SfgMemberVariable, SfgMethod, @@ -34,7 +35,6 @@ def interleave(*iters): class SfgGeneralPrinter: - @visitor def visit(self, obj: object) -> str: raise SfgException(f"Can't print object of type {type(obj)}") @@ -56,7 +56,11 @@ class SfgGeneralPrinter: def prelude(self, ctx: SfgContext) -> str: if ctx.prelude_comment: - return "/*\n" + indent(ctx.prelude_comment, "* ", predicate=lambda _: True) + "*/\n" + return ( + "/*\n" + + indent(ctx.prelude_comment, "* ", predicate=lambda _: True) + + "*/\n" + ) else: return "" @@ -66,7 +70,6 @@ class SfgGeneralPrinter: class SfgHeaderPrinter(SfgGeneralPrinter): - def __init__(self, ctx: SfgContext, output_spec: SfgOutputSpec): self._output_spec = output_spec self._ctx = ctx @@ -92,10 +95,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): if fq_namespace is not None: code += f"namespace {fq_namespace} {{\n\n" - parts = interleave( - ctx.declarations_ordered(), - repeat(SfgEmptyLines(1)) - ) + parts = interleave(ctx.declarations_ordered(), repeat(SfgEmptyLines(1))) code += "\n".join(self.visit(p) for p in parts) @@ -131,6 +131,10 @@ class SfgHeaderPrinter(SfgGeneralPrinter): return code + @visit.case(SfgInClassDefinition) + def sfg_inclassdef(self, definition: SfgInClassDefinition): + return definition.text + @visit.case(SfgConstructor) def sfg_constructor(self, constr: SfgConstructor): code = f"{constr.owning_class.class_name} (" @@ -153,7 +157,11 @@ class SfgHeaderPrinter(SfgGeneralPrinter): code = f"{method.return_type} {method.name} ({self.param_list(method)})" code += "const" if method.const else "" if method.inline: - code += " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n" + code += ( + " {\n" + + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + + "}\n" + ) else: code += ";" return code @@ -201,9 +209,9 @@ class SfgImplPrinter(SfgGeneralPrinter): [delimiter("Functions")], ctx.functions(), [delimiter("Class Methods")], - ctx.classes() + ctx.classes(), ), - repeat(SfgEmptyLines(1)) + repeat(SfgEmptyLines(1)), ) code += "\n".join(self.visit(p) for p in parts) @@ -227,7 +235,9 @@ class SfgImplPrinter(SfgGeneralPrinter): @visit.case(SfgFunction) def function(self, func: SfgFunction) -> str: code = f"{func.return_type} {func.name} ({self.param_list(func)})" - code += "{\n" + self._ctx.codestyle.indent(func.tree.get_code(self._ctx)) + "}\n" + code += ( + "{\n" + self._ctx.codestyle.indent(func.tree.get_code(self._ctx)) + "}\n" + ) return code @visit.case(SfgClass) @@ -240,5 +250,7 @@ class SfgImplPrinter(SfgGeneralPrinter): const_qual = "const" if method.const else "" code = f"{method.return_type} {method.owning_class.class_name}::{method.name}" code += f"({self.param_list(method)}) {const_qual}" - code += " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n" + code += ( + " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n" + ) return code diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index 0c1b21d5a73e42aba22098cf45c2ff008f16c550..a33c98d5434f4912a4335339318d2e684ab2b67c 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -242,6 +242,19 @@ class SfgClassMember(ABC): return self._visibility +class SfgInClassDefinition(SfgClassMember): + def __init__(self, text: str, cls: SfgClass, visibility: SfgVisibility): + SfgClassMember.__init__(self, cls, visibility) + self._text = text + + @property + def text(self) -> str: + return self._text + + def __str__(self) -> str: + return self._text + + class SfgMemberVariable(SrcObject, SfgClassMember): def __init__( self, @@ -321,6 +334,7 @@ class SfgClass: self._class_keyword = class_keyword self._bases_classes = tuple(bases) + self._definitions: list[SfgInClassDefinition] = [] self._constructors: list[SfgConstructor] = [] self._methods: dict[str, SfgMethod] = dict() self._member_vars: dict[str, SfgMemberVariable] = dict() @@ -344,6 +358,7 @@ class SfgClass: def members( self, visibility: SfgVisibility | None = None ) -> Generator[SfgClassMember, None, None]: + yield from self.definitions(visibility) yield from self.member_variables(visibility) yield from self.constructors(visibility) yield from self.methods(visibility) @@ -355,9 +370,22 @@ class SfgClass: self.add_member_variable(member) elif isinstance(member, SfgMethod): self.add_method(member) + elif isinstance(member, SfgInClassDefinition): + self.add_definition(member) else: raise SfgException(f"{member} is not a valid class member.") + def definitions( + self, visibility: SfgVisibility | None = None + ) -> Generator[SfgInClassDefinition, None, None]: + if visibility is not None: + yield from filter(lambda m: m.visibility == visibility, self._definitions) + else: + yield from self._definitions + + def add_definition(self, definition: SfgInClassDefinition): + self._definitions.append(definition) + def constructors( self, visibility: SfgVisibility | None = None ) -> Generator[SfgConstructor, None, None]: