Skip to content
Snippets Groups Projects
Commit af967256 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

extended methods and functions:

 - non-void return types
 - const qualifiers
 - inline methods
parent 0cc13bbb
No related branches found
No related tags found
No related merge requests found
Pipeline #57883 passed
# type: ignore
from pystencilssfg import SourceFileGenerator, SfgConfiguration
from pystencilssfg.configuration import SfgCodeStyle
from pystencilssfg.types import SrcType
from pystencilssfg.source_concepts import SrcObject
from pystencilssfg.source_components import SfgClass, SfgMemberVariable, SfgConstructor, SfgMethod, SfgVisibility
......@@ -33,6 +34,30 @@ with SourceFileGenerator(sfg_config) as sfg:
visibility=SfgVisibility.PUBLIC
))
cls.add_method(SfgMethod(
"inlineConst",
sfg.seq(
"return -1.0;"
),
cls,
visibility=SfgVisibility.PUBLIC,
return_type=SrcType("double"),
inline=True,
const=True
))
cls.add_method(SfgMethod(
"awesomeMethod",
sfg.seq(
"return 2.0f;"
),
cls,
visibility=SfgVisibility.PRIVATE,
return_type=SrcType("float"),
inline=False,
const=True
))
cls.add_member_variable(
SfgMemberVariable(
"stuff", "std::vector< int >",
......
......@@ -60,6 +60,10 @@ class SfgGeneralPrinter:
else:
return ""
def param_list(self, func: SfgFunction) -> str:
params = sorted(list(func.parameters), key=lambda p: p.name)
return ", ".join(f"{param.dtype} {param.name}" for param in params)
class SfgHeaderPrinter(SfgGeneralPrinter):
......@@ -108,7 +112,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
def function(self, func: SfgFunction):
params = sorted(list(func.parameters), key=lambda p: p.name)
param_list = ", ".join(f"{param.dtype} {param.name}" for param in params)
return f"void {func.name} ( {param_list} );"
return f"{func.return_type} {func.name} ( {param_list} );"
@visit.case(SfgClass)
def sfg_class(self, cls: SfgClass):
......@@ -150,9 +154,12 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
@visit.case(SfgMethod)
def sfg_method(self, method: SfgMethod):
code = f"void {method.name} ("
code += ", ".join(f"{param.dtype} {param.name}" for param in method.parameters)
code += ");"
code = f"{method.return_type} {method.name} ({self.param_list(method)})"
code += "const" if method.const else ")"
if method.inline:
code += " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n"
else:
code += ";"
return code
......@@ -223,20 +230,19 @@ class SfgImplPrinter(SfgGeneralPrinter):
@visit.case(SfgFunction)
def function(self, func: SfgFunction) -> str:
return self.method_or_func(func, func.name)
code = f"{func.return_type} {func.name} ({self.param_list(func)})"
code += "{\n" + self._ctx.codestyle.indent(func.tree.get_code(self._ctx)) + "}\n"
return code
@visit.case(SfgClass)
def sfg_class(self, cls: SfgClass) -> str:
return "\n".join(self.visit(m) for m in cls.methods())
methods = filter(lambda m: not m.inline, cls.methods())
return "\n".join(self.visit(m) for m in methods)
@visit.case(SfgMethod)
def sfg_method(self, method: SfgMethod) -> str:
return self.method_or_func(method, f"{method.owning_class.class_name}::{method.name}")
def method_or_func(self, func: SfgFunction, fully_qualified_name: str) -> str:
params = sorted(list(func.parameters), key=lambda p: p.name)
param_list = ", ".join(f"{param.dtype} {param.name}" for param in params)
code = f"void {fully_qualified_name} ({param_list}) {{\n"
code += self._ctx.codestyle.indent(func.tree.get_code(self._ctx))
code += "}\n"
const_qual = "const" if method.const else ""
code = f"{method.return_type} {method.owning_class.class_name}::{method.name}"
code += f"({self.param_list(method)}) {const_qual}"
code += " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n"
return code
......@@ -13,7 +13,6 @@ from .source_concepts import SrcObject
from .exceptions import SfgException
if TYPE_CHECKING:
from .context import SfgContext
from .tree import SfgCallTreeNode
......@@ -173,9 +172,10 @@ class SfgKernelHandle:
class SfgFunction:
def __init__(self, name: str, tree: SfgCallTreeNode):
def __init__(self, name: str, tree: SfgCallTreeNode, return_type: SrcType = SrcType("void")):
self._name = name
self._tree = tree
self._return_type = return_type
from .visitors.tree_visitors import ExpandingParameterCollector
......@@ -194,8 +194,9 @@ class SfgFunction:
def tree(self):
return self._tree
def get_code(self, ctx: SfgContext):
return self._tree.get_code(ctx)
@property
def return_type(self) -> SrcType:
return self._return_type
class SfgVisibility(Enum):
......@@ -258,10 +259,24 @@ class SfgMethod(SfgFunction, SfgClassMember):
tree: SfgCallTreeNode,
cls: SfgClass,
visibility: SfgVisibility = SfgVisibility.PUBLIC,
return_type: SrcType = SrcType("void"),
inline: bool = False,
const: bool = False
):
SfgFunction.__init__(self, name, tree)
SfgFunction.__init__(self, name, tree, return_type=return_type)
SfgClassMember.__init__(self, cls, visibility)
self._inline = inline
self._const = const
@property
def inline(self) -> bool:
return self._inline
@property
def const(self) -> bool:
return self._const
class SfgConstructor(SfgClassMember):
def __init__(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment