diff --git a/integration/test_class_composer.py b/integration/test_class_composer.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6f48d6b8179b029d156822932848b01ef11c0d --- /dev/null +++ b/integration/test_class_composer.py @@ -0,0 +1,55 @@ +# type: ignore +from pystencilssfg import SourceFileGenerator, SfgConfiguration, SfgComposer +from pystencilssfg.configuration import SfgCodeStyle +from pystencilssfg.composer import SfgClassComposer +from pystencilssfg.source_concepts import SrcObject + +from pystencils import fields, kernel + +sfg_config = SfgConfiguration( + output_directory="out/test_class_composer", + outer_namespace="gen_code", + codestyle=SfgCodeStyle( + code_style="Mozilla", + force_clang_format=True + ) +) + +f, g = fields("f, g(1): double[2D]") + +with SourceFileGenerator(sfg_config) as ctx: + sfg = SfgComposer(ctx) + c = SfgClassComposer(ctx) + + @kernel + def assignments(): + f[0, 0] @= 3 * g[0, 0] + + khandle = sfg.kernels.create(assignments) + + c.struct("DataStruct")( + SrcObject("coord", "uint32_t"), + SrcObject("value", "float") + ), + + c.klass("MyClass", bases=("MyBaseClass",))( + # class body sequencer + + c.private( + c.var("a_", "int"), + + c.method("getX", returns="int")( + "return 2.0;" + ) + ), + + c.constructor(SrcObject("a", "int")) + .init("a_(a)") + .body( + 'cout << "Hi!" << endl;' + ), + + c.public( + + ) + ) diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py index 277ce8ddcb85c5e25530e0aff5a5ccb7b827f25f..f4ef3bf33c0f035a25c475d1c0207f24f14ddbb1 100644 --- a/src/pystencilssfg/composer.py +++ b/src/pystencilssfg/composer.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Sequence from abc import ABC, abstractmethod import numpy as np +from functools import partial from pystencils import Field from pystencils.astnodes import KernelFunction @@ -23,7 +24,9 @@ from .source_components import ( SfgKernelNamespace, SfgKernelHandle, SfgClass, + SfgClassMember, SfgConstructor, + SfgMethod, SfgMemberVariable, SfgClassKeyword, SfgVisibility, @@ -330,6 +333,131 @@ def parse_include(incl: str | SfgHeaderInclude): return SfgHeaderInclude(incl, system_header=system_header) +class SfgClassComposer: + def __init__(self, ctx: SfgContext): + self._ctx = ctx + + class PartialMember: + def __init__(self, member_type: type[SfgClassMember], *args, **kwargs): + assert issubclass(member_type, SfgClassMember) + + self._type = member_type + self._partial = partial(member_type, *args, **kwargs) + + @property + def member_type(self): + return self._type + + def resolve(self, cls: SfgClass, visibility: SfgVisibility) -> SfgClassMember: + return self._partial(cls=cls, visibility=visibility) + + class VisibilityContext: + def __init__(self, visibility: SfgVisibility): + self._vis = visibility + self._partial_members: list[SfgClassComposer.PartialMember] = [] + + def members(self): + yield from self._partial_members + + def __call__(self, *args: SfgClassComposer.PartialMember | SrcObject): + for arg in args: + if isinstance(arg, SrcObject): + self._partial_members.append(SfgClassComposer.PartialMember( + SfgMemberVariable, + name=arg.name, + dtype=arg.dtype + )) + else: + self._partial_members.append(arg) + + return self + + def resolve(self, cls: SfgClass) -> list[SfgClassMember]: + return [part.resolve(cls=cls, visibility=self._vis) for part in self._partial_members] + + class ConstructorBuilder: + def __init__(self, *params: SrcObject): + self._params = params + self._initializers: list[str] = [] + + def init(self, initializer: str) -> SfgClassComposer.ConstructorBuilder: + self._initializers.append(initializer) + return self + + def body(self, body: str): + return SfgClassComposer.PartialMember( + SfgConstructor, + parameters=self._params, + initializers=self._initializers, + body=body + ) + + def klass(self, class_name: str, bases: Sequence[str] = ()): + return self._class(class_name, SfgClassKeyword.CLASS, bases) + + def struct(self, class_name: str, bases: Sequence[str] = ()): + return self._class(class_name, SfgClassKeyword.STRUCT, bases) + + @property + def public(self) -> SfgClassComposer.VisibilityContext: + return SfgClassComposer.VisibilityContext(SfgVisibility.PUBLIC) + + @property + def private(self) -> SfgClassComposer.VisibilityContext: + return SfgClassComposer.VisibilityContext(SfgVisibility.PRIVATE) + + def var(self, name: str, dtype: SrcType): + return SfgClassComposer.PartialMember(SfgMemberVariable, name=name, dtype=dtype) + + def constructor(self, *params): + return SfgClassComposer.ConstructorBuilder(*params) + + def method( + self, + name: str, + returns: SrcType = SrcType("void"), + inline: bool = False, + const: bool = False): + + def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): + tree = make_sequence(*args) + return SfgClassComposer.PartialMember( + SfgMethod, + name=name, + tree=tree, + return_type=returns, + inline=inline, + const=const + ) + + return sequencer + + # INTERNALS + + def _class(self, class_name: str, keyword: SfgClassKeyword, bases: Sequence[str]): + if self._ctx.get_class(class_name) is not None: + raise ValueError(f"Class or struct {class_name} already exists.") + + cls = SfgClass(class_name, class_keyword=keyword, bases=bases) + self._ctx.add_class(cls) + + def sequencer(*args): + default_context = SfgClassComposer.VisibilityContext(SfgVisibility.DEFAULT) + for arg in args: + if isinstance(arg, SfgClassComposer.VisibilityContext): + for member in arg.resolve(cls): + cls.add_member(member) + elif isinstance(arg, (SfgClassComposer.PartialMember, SrcObject)): + default_context(arg) + else: + raise SfgException(f"{arg} is not a valid class member.") + + for member in default_context.resolve(cls): + cls.add_member(member) + + return sequencer + + def struct_from_numpy_dtype( struct_name: str, dtype: np.dtype, add_constructor: bool = True ): diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index 74514b2e44dcde7659008620d18bb2c472d68b5c..6a771568d4af44e63a992c6517ca145c93182e95 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -151,7 +151,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgMethod) def sfg_method(self, method: SfgMethod): code = f"{method.return_type} {method.name} ({self.param_list(method)})" - code += "const" if method.const else ")" + code += "const" if method.const else "" if method.inline: code += " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n" else: diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index 1346059a08cf0744d1ce5b17c0925a57fdb7a292..0c1b21d5a73e42aba22098cf45c2ff008f16c550 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -246,11 +246,11 @@ class SfgMemberVariable(SrcObject, SfgClassMember): def __init__( self, name: str, - type: SrcType, + dtype: SrcType, cls: SfgClass, visibility: SfgVisibility = SfgVisibility.PRIVATE, ): - SrcObject.__init__(self, name, type) + SrcObject.__init__(self, name, dtype) SfgClassMember.__init__(self, cls, visibility) @@ -314,6 +314,9 @@ class SfgClass: class_keyword: SfgClassKeyword = SfgClassKeyword.CLASS, bases: Sequence[str] = (), ): + if isinstance(bases, str): + raise ValueError("Base classes must be given as a sequence.") + self._class_name = class_name self._class_keyword = class_keyword self._bases_classes = tuple(bases) @@ -345,6 +348,16 @@ class SfgClass: yield from self.constructors(visibility) yield from self.methods(visibility) + def add_member(self, member: SfgClassMember): + if isinstance(member, SfgConstructor): + self.add_constructor(member) + elif isinstance(member, SfgMemberVariable): + self.add_member_variable(member) + elif isinstance(member, SfgMethod): + self.add_method(member) + else: + raise SfgException(f"{member} is not a valid class member.") + def constructors( self, visibility: SfgVisibility | None = None ) -> Generator[SfgConstructor, None, None]: