Skip to content
Snippets Groups Projects
Commit 463207f2 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

class composer prototype

parent b4c0f95c
No related branches found
No related tags found
No related merge requests found
Pipeline #58232 passed
# type: ignore
from pystencilssfg import SourceFileGenerator, SfgConfiguration, SfgComposer
from pystencilssfg.configuration import SfgCodeStyle
from pystencilssfg.composer import SfgClassComposer
from pystencilssfg.source_concepts import SrcObject
from pystencils import fields, kernel
sfg_config = SfgConfiguration(
output_directory="out/test_class_composer",
outer_namespace="gen_code",
codestyle=SfgCodeStyle(
code_style="Mozilla",
force_clang_format=True
)
)
f, g = fields("f, g(1): double[2D]")
with SourceFileGenerator(sfg_config) as ctx:
sfg = SfgComposer(ctx)
c = SfgClassComposer(ctx)
@kernel
def assignments():
f[0, 0] @= 3 * g[0, 0]
khandle = sfg.kernels.create(assignments)
c.struct("DataStruct")(
SrcObject("coord", "uint32_t"),
SrcObject("value", "float")
),
c.klass("MyClass", bases=("MyBaseClass",))(
# class body sequencer
c.private(
c.var("a_", "int"),
c.method("getX", returns="int")(
"return 2.0;"
)
),
c.constructor(SrcObject("a", "int"))
.init("a_(a)")
.body(
'cout << "Hi!" << endl;'
),
c.public(
)
)
...@@ -2,6 +2,7 @@ from __future__ import annotations ...@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Sequence from typing import TYPE_CHECKING, Sequence
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import numpy as np import numpy as np
from functools import partial
from pystencils import Field from pystencils import Field
from pystencils.astnodes import KernelFunction from pystencils.astnodes import KernelFunction
...@@ -23,7 +24,9 @@ from .source_components import ( ...@@ -23,7 +24,9 @@ from .source_components import (
SfgKernelNamespace, SfgKernelNamespace,
SfgKernelHandle, SfgKernelHandle,
SfgClass, SfgClass,
SfgClassMember,
SfgConstructor, SfgConstructor,
SfgMethod,
SfgMemberVariable, SfgMemberVariable,
SfgClassKeyword, SfgClassKeyword,
SfgVisibility, SfgVisibility,
...@@ -330,6 +333,131 @@ def parse_include(incl: str | SfgHeaderInclude): ...@@ -330,6 +333,131 @@ def parse_include(incl: str | SfgHeaderInclude):
return SfgHeaderInclude(incl, system_header=system_header) return SfgHeaderInclude(incl, system_header=system_header)
class SfgClassComposer:
def __init__(self, ctx: SfgContext):
self._ctx = ctx
class PartialMember:
def __init__(self, member_type: type[SfgClassMember], *args, **kwargs):
assert issubclass(member_type, SfgClassMember)
self._type = member_type
self._partial = partial(member_type, *args, **kwargs)
@property
def member_type(self):
return self._type
def resolve(self, cls: SfgClass, visibility: SfgVisibility) -> SfgClassMember:
return self._partial(cls=cls, visibility=visibility)
class VisibilityContext:
def __init__(self, visibility: SfgVisibility):
self._vis = visibility
self._partial_members: list[SfgClassComposer.PartialMember] = []
def members(self):
yield from self._partial_members
def __call__(self, *args: SfgClassComposer.PartialMember | SrcObject):
for arg in args:
if isinstance(arg, SrcObject):
self._partial_members.append(SfgClassComposer.PartialMember(
SfgMemberVariable,
name=arg.name,
dtype=arg.dtype
))
else:
self._partial_members.append(arg)
return self
def resolve(self, cls: SfgClass) -> list[SfgClassMember]:
return [part.resolve(cls=cls, visibility=self._vis) for part in self._partial_members]
class ConstructorBuilder:
def __init__(self, *params: SrcObject):
self._params = params
self._initializers: list[str] = []
def init(self, initializer: str) -> SfgClassComposer.ConstructorBuilder:
self._initializers.append(initializer)
return self
def body(self, body: str):
return SfgClassComposer.PartialMember(
SfgConstructor,
parameters=self._params,
initializers=self._initializers,
body=body
)
def klass(self, class_name: str, bases: Sequence[str] = ()):
return self._class(class_name, SfgClassKeyword.CLASS, bases)
def struct(self, class_name: str, bases: Sequence[str] = ()):
return self._class(class_name, SfgClassKeyword.STRUCT, bases)
@property
def public(self) -> SfgClassComposer.VisibilityContext:
return SfgClassComposer.VisibilityContext(SfgVisibility.PUBLIC)
@property
def private(self) -> SfgClassComposer.VisibilityContext:
return SfgClassComposer.VisibilityContext(SfgVisibility.PRIVATE)
def var(self, name: str, dtype: SrcType):
return SfgClassComposer.PartialMember(SfgMemberVariable, name=name, dtype=dtype)
def constructor(self, *params):
return SfgClassComposer.ConstructorBuilder(*params)
def method(
self,
name: str,
returns: SrcType = SrcType("void"),
inline: bool = False,
const: bool = False):
def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder):
tree = make_sequence(*args)
return SfgClassComposer.PartialMember(
SfgMethod,
name=name,
tree=tree,
return_type=returns,
inline=inline,
const=const
)
return sequencer
# INTERNALS
def _class(self, class_name: str, keyword: SfgClassKeyword, bases: Sequence[str]):
if self._ctx.get_class(class_name) is not None:
raise ValueError(f"Class or struct {class_name} already exists.")
cls = SfgClass(class_name, class_keyword=keyword, bases=bases)
self._ctx.add_class(cls)
def sequencer(*args):
default_context = SfgClassComposer.VisibilityContext(SfgVisibility.DEFAULT)
for arg in args:
if isinstance(arg, SfgClassComposer.VisibilityContext):
for member in arg.resolve(cls):
cls.add_member(member)
elif isinstance(arg, (SfgClassComposer.PartialMember, SrcObject)):
default_context(arg)
else:
raise SfgException(f"{arg} is not a valid class member.")
for member in default_context.resolve(cls):
cls.add_member(member)
return sequencer
def struct_from_numpy_dtype( def struct_from_numpy_dtype(
struct_name: str, dtype: np.dtype, add_constructor: bool = True struct_name: str, dtype: np.dtype, add_constructor: bool = True
): ):
......
...@@ -151,7 +151,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): ...@@ -151,7 +151,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
@visit.case(SfgMethod) @visit.case(SfgMethod)
def sfg_method(self, method: SfgMethod): def sfg_method(self, method: SfgMethod):
code = f"{method.return_type} {method.name} ({self.param_list(method)})" code = f"{method.return_type} {method.name} ({self.param_list(method)})"
code += "const" if method.const else ")" code += "const" if method.const else ""
if method.inline: if method.inline:
code += " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n" code += " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n"
else: else:
......
...@@ -246,11 +246,11 @@ class SfgMemberVariable(SrcObject, SfgClassMember): ...@@ -246,11 +246,11 @@ class SfgMemberVariable(SrcObject, SfgClassMember):
def __init__( def __init__(
self, self,
name: str, name: str,
type: SrcType, dtype: SrcType,
cls: SfgClass, cls: SfgClass,
visibility: SfgVisibility = SfgVisibility.PRIVATE, visibility: SfgVisibility = SfgVisibility.PRIVATE,
): ):
SrcObject.__init__(self, name, type) SrcObject.__init__(self, name, dtype)
SfgClassMember.__init__(self, cls, visibility) SfgClassMember.__init__(self, cls, visibility)
...@@ -314,6 +314,9 @@ class SfgClass: ...@@ -314,6 +314,9 @@ class SfgClass:
class_keyword: SfgClassKeyword = SfgClassKeyword.CLASS, class_keyword: SfgClassKeyword = SfgClassKeyword.CLASS,
bases: Sequence[str] = (), bases: Sequence[str] = (),
): ):
if isinstance(bases, str):
raise ValueError("Base classes must be given as a sequence.")
self._class_name = class_name self._class_name = class_name
self._class_keyword = class_keyword self._class_keyword = class_keyword
self._bases_classes = tuple(bases) self._bases_classes = tuple(bases)
...@@ -345,6 +348,16 @@ class SfgClass: ...@@ -345,6 +348,16 @@ class SfgClass:
yield from self.constructors(visibility) yield from self.constructors(visibility)
yield from self.methods(visibility) yield from self.methods(visibility)
def add_member(self, member: SfgClassMember):
if isinstance(member, SfgConstructor):
self.add_constructor(member)
elif isinstance(member, SfgMemberVariable):
self.add_member_variable(member)
elif isinstance(member, SfgMethod):
self.add_method(member)
else:
raise SfgException(f"{member} is not a valid class member.")
def constructors( def constructors(
self, visibility: SfgVisibility | None = None self, visibility: SfgVisibility | None = None
) -> Generator[SfgConstructor, None, None]: ) -> Generator[SfgConstructor, None, None]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment