Skip to content
Snippets Groups Projects
Commit bf1eb5a3 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

refactoring of class modelling

parent 5c29754f
No related branches found
No related tags found
No related merge requests found
Pipeline #58278 failed
......@@ -28,51 +28,41 @@ with SourceFileGenerator(sfg_config) as ctx:
khandle = sfg.kernels.create(assignments)
cls = SfgClass("MyClass")
cls.add_method(SfgMethod(
cls.default.append_member(SfgMethod(
"callKernel",
sfg.call(khandle),
cls,
visibility=SfgVisibility.PUBLIC
sfg.call(khandle)
))
cls.add_method(SfgMethod(
cls.default.append_member(SfgMethod(
"inlineConst",
sfg.seq(
"return -1.0;"
),
cls,
visibility=SfgVisibility.PUBLIC,
return_type=SrcType("double"),
inline=True,
const=True
))
cls.add_method(SfgMethod(
cls.default.append_member(SfgMethod(
"awesomeMethod",
sfg.seq(
"return 2.0f;"
),
cls,
visibility=SfgVisibility.PRIVATE,
return_type=SrcType("float"),
inline=False,
const=True
))
cls.add_member_variable(
cls.default.append_member(
SfgMemberVariable(
"stuff", "std::vector< int >",
cls,
SfgVisibility.PRIVATE
"stuff", "std::vector< int >"
)
)
cls.add_constructor(
cls.default.append_member(
SfgConstructor(
cls,
[SrcObject("stuff", "std::vector< int > &")],
["stuff_(stuff)"],
visibility=SfgVisibility.PUBLIC
["stuff_(stuff)"]
)
)
......
......@@ -487,13 +487,13 @@ def struct_from_numpy_dtype(
arg = SrcObject(f"{member_name}_", member_type)
cls.add_member_variable(member)
cls._add_member_variable(member)
constr_params.append(arg)
constr_inits.append(f"{member}({arg})")
if add_constructor:
cls.add_constructor(
cls._add_constructor(
SfgConstructor(
cls, constr_params, constr_inits, visibility=SfgVisibility.DEFAULT
)
......
......@@ -23,6 +23,7 @@ from ..source_components import (
SfgMemberVariable,
SfgMethod,
SfgVisibility,
SfgVisibilityBlock
)
......@@ -118,19 +119,24 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
code += f" : {','.join(cls.base_classes)}\n"
code += "{\n"
for visibility in (
SfgVisibility.DEFAULT,
SfgVisibility.PUBLIC,
SfgVisibility.PRIVATE,
):
if visibility != SfgVisibility.DEFAULT:
code += f"\n{visibility}:\n"
for member in cls.members(visibility):
code += self._ctx.codestyle.indent(self.visit(member)) + "\n"
for block in cls.visibility_blocks():
code += self.visit(block) + "\n"
code += "};\n"
return code
@visit.case(SfgVisibilityBlock)
def vis_block(self, block: SfgVisibilityBlock) -> str:
code = ""
if block.visibility != SfgVisibility.DEFAULT:
code += f"{block.visibility}:\n"
code += self._ctx.codestyle.indent(
"\n".join(self.visit(m) for m in block.members())
)
return code
@visit.case(SfgInClassDefinition)
def sfg_inclassdef(self, definition: SfgInClassDefinition):
return definition.text
......
......@@ -4,6 +4,7 @@ from abc import ABC
from enum import Enum, auto
from typing import TYPE_CHECKING, Sequence, Generator
from dataclasses import replace
from itertools import chain
from pystencils import CreateKernelConfig, create_kernel
from pystencils.astnodes import KernelFunction
......@@ -204,6 +205,7 @@ class SfgFunction:
class SfgVisibility(Enum):
DEFAULT = auto()
PRIVATE = auto()
PROTECTED = auto()
PUBLIC = auto()
def __str__(self) -> str:
......@@ -212,6 +214,8 @@ class SfgVisibility(Enum):
return ""
case SfgVisibility.PRIVATE:
return "private"
case SfgVisibility.PROTECTED:
return "protected"
case SfgVisibility.PUBLIC:
return "public"
......@@ -229,22 +233,70 @@ class SfgClassKeyword(Enum):
class SfgClassMember(ABC):
def __init__(self, cls: SfgClass, visibility: SfgVisibility):
self._cls = cls
self._visibility = visibility
def __init__(self):
self._cls: SfgClass | None = None
self._visibility: SfgVisibility | None = None
@property
def owning_class(self) -> SfgClass:
if self._cls is None:
raise SfgException(f"{self} is not bound to a class.")
return self._cls
@property
def visibility(self) -> SfgVisibility:
if self._visibility is None:
raise SfgException(f"{self} is not bound to a class and therefore has no visibility.")
return self._visibility
@property
def is_bound(self) -> bool:
return self._cls is not None
def _bind(self, cls: SfgClass, vis: SfgVisibility):
if self.is_bound:
raise SfgException(
f"Binding {self} to class {cls.class_name} failed: "
f"{self} was already bound to {self.owning_class.class_name}"
)
self._cls = cls
self._vis = vis
class SfgVisibilityBlock:
def __init__(self, visibility: SfgVisibility) -> None:
self._vis = visibility
self._members: list[SfgClassMember] = []
self._cls: SfgClass | None = None
@property
def visibility(self) -> SfgVisibility:
return self._vis
def append_member(self, member: SfgClassMember):
if self._cls is not None:
self._cls._add_member(member, self._vis)
self._members.append(member)
def members(self) -> Generator[SfgClassMember, None, None]:
yield from self._members
@property
def is_bound(self) -> bool:
return self._cls is not None
def _bind(self, cls: SfgClass):
if self._cls is not None:
raise SfgException(
f"Binding visibility block to class {cls.class_name} failed: "
f"was already bound to {self._cls.class_name}"
)
self._cls = cls
class SfgInClassDefinition(SfgClassMember):
def __init__(self, text: str, cls: SfgClass, visibility: SfgVisibility):
SfgClassMember.__init__(self, cls, visibility)
def __init__(self, text: str):
SfgClassMember.__init__(self)
self._text = text
@property
......@@ -259,12 +311,10 @@ class SfgMemberVariable(SrcObject, SfgClassMember):
def __init__(
self,
name: str,
dtype: SrcType,
cls: SfgClass,
visibility: SfgVisibility = SfgVisibility.PRIVATE,
dtype: SrcType
):
SrcObject.__init__(self, name, dtype)
SfgClassMember.__init__(self, cls, visibility)
SfgClassMember.__init__(self)
class SfgMethod(SfgFunction, SfgClassMember):
......@@ -272,14 +322,12 @@ class SfgMethod(SfgFunction, SfgClassMember):
self,
name: str,
tree: SfgCallTreeNode,
cls: SfgClass,
visibility: SfgVisibility = SfgVisibility.PUBLIC,
return_type: SrcType = SrcType("void"),
inline: bool = False,
const: bool = False,
):
SfgFunction.__init__(self, name, tree, return_type=return_type)
SfgClassMember.__init__(self, cls, visibility)
SfgClassMember.__init__(self)
self._inline = inline
self._const = const
......@@ -296,13 +344,11 @@ class SfgMethod(SfgFunction, SfgClassMember):
class SfgConstructor(SfgClassMember):
def __init__(
self,
cls: SfgClass,
parameters: Sequence[SrcObject] = (),
initializers: Sequence[str] = (),
body: str = "",
visibility: SfgVisibility = SfgVisibility.PUBLIC,
body: str = ""
):
SfgClassMember.__init__(self, cls, visibility)
SfgClassMember.__init__(self)
self._parameters = tuple(parameters)
self._initializers = tuple(initializers)
self._body = body
......@@ -321,6 +367,21 @@ class SfgConstructor(SfgClassMember):
class SfgClass:
"""Models a C++ class.
### Adding members to classes
Members are never added directly to a class. Instead, they are added to
a SfgVisibilityBlock which defines their syntactic position and visibility modifier
in the code.
At the top of every class, there is a default visibility block
accessible through the `default` property.
To add members with custom visibility, create a new SfgVisibilityBlock,
add members to the block, and add the block using `append_visibility_block`.
A more succinct interface for constructing classes is available through the
[SfgClassComposer][pystencilssfg.composer.SfgClassComposer].
"""
def __init__(
self,
class_name: str,
......@@ -334,6 +395,10 @@ class SfgClass:
self._class_keyword = class_keyword
self._bases_classes = tuple(bases)
self._default_block = SfgVisibilityBlock(SfgVisibility.DEFAULT)
self._default_block._bind(self)
self._blocks = [self._default_block]
self._definitions: list[SfgInClassDefinition] = []
self._constructors: list[SfgConstructor] = []
self._methods: dict[str, SfgMethod] = dict()
......@@ -355,25 +420,25 @@ class SfgClass:
def class_keyword(self) -> SfgClassKeyword:
return self._class_keyword
@property
def default(self) -> SfgVisibilityBlock:
return self._default_block
def append_visibility_block(self, block: SfgVisibilityBlock):
block._bind(self)
for m in block.members():
self._add_member(m, block.visibility)
self._blocks.append(block)
def visibility_blocks(self) -> Generator[SfgVisibilityBlock, None, None]:
yield from self._blocks
def members(
self, visibility: SfgVisibility | None = None
) -> Generator[SfgClassMember, None, None]:
yield from self.definitions(visibility)
yield from self.member_variables(visibility)
yield from self.constructors(visibility)
yield from self.methods(visibility)
def add_member(self, member: SfgClassMember):
if isinstance(member, SfgConstructor):
self.add_constructor(member)
elif isinstance(member, SfgMemberVariable):
self.add_member_variable(member)
elif isinstance(member, SfgMethod):
self.add_method(member)
elif isinstance(member, SfgInClassDefinition):
self.add_definition(member)
else:
raise SfgException(f"{member} is not a valid class member.")
yield from chain.from_iterable(
b.members() for b in filter(lambda b: b.visibility == visibility, self._blocks)
)
def definitions(
self, visibility: SfgVisibility | None = None
......@@ -383,8 +448,15 @@ class SfgClass:
else:
yield from self._definitions
def add_definition(self, definition: SfgInClassDefinition):
self._definitions.append(definition)
def member_variables(
self, visibility: SfgVisibility | None = None
) -> Generator[SfgMemberVariable, None, None]:
if visibility is not None:
yield from filter(
lambda m: m.visibility == visibility, self._member_vars.values()
)
else:
yield from self._member_vars.values()
def constructors(
self, visibility: SfgVisibility | None = None
......@@ -394,10 +466,6 @@ class SfgClass:
else:
yield from self._constructors
def add_constructor(self, constr: SfgConstructor):
# TODO: Check for signature conflicts?
self._constructors.append(constr)
def methods(
self, visibility: SfgVisibility | None = None
) -> Generator[SfgMethod, None, None]:
......@@ -408,7 +476,30 @@ class SfgClass:
else:
yield from self._methods.values()
def add_method(self, method: SfgMethod):
# PRIVATE
def _add_member(self, member: SfgClassMember, vis: SfgVisibility):
if isinstance(member, SfgConstructor):
self._add_constructor(member)
elif isinstance(member, SfgMemberVariable):
self._add_member_variable(member)
elif isinstance(member, SfgMethod):
self._add_method(member)
elif isinstance(member, SfgInClassDefinition):
self._add_definition(member)
else:
raise SfgException(f"{member} is not a valid class member.")
member._bind(self, vis)
def _add_definition(self, definition: SfgInClassDefinition):
self._definitions.append(definition)
def _add_constructor(self, constr: SfgConstructor):
# TODO: Check for signature conflicts?
self._constructors.append(constr)
def _add_method(self, method: SfgMethod):
if method.name in self._methods:
raise SfgException(
f"Duplicate method name {method.name} in class {self._class_name}"
......@@ -416,17 +507,7 @@ class SfgClass:
self._methods[method.name] = method
def member_variables(
self, visibility: SfgVisibility | None = None
) -> Generator[SfgMemberVariable, None, None]:
if visibility is not None:
yield from filter(
lambda m: m.visibility == visibility, self._member_vars.values()
)
else:
yield from self._member_vars.values()
def add_member_variable(self, variable: SfgMemberVariable):
def _add_member_variable(self, variable: SfgMemberVariable):
if variable.name in self._member_vars:
raise SfgException(
f"Duplicate field name {variable.name} in class {self._class_name}"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment