diff --git a/integration/test_classes.py b/integration/test_classes.py index 192807b98e9eb6ebe4ad240a13c22eef1bfedb71..1aaef023c8b2b32250ae6f269bf2fdc952709fc4 100644 --- a/integration/test_classes.py +++ b/integration/test_classes.py @@ -1,6 +1,7 @@ # type: ignore from pystencilssfg import SourceFileGenerator, SfgConfiguration from pystencilssfg.configuration import SfgCodeStyle +from pystencilssfg.types import SrcType from pystencilssfg.source_concepts import SrcObject from pystencilssfg.source_components import SfgClass, SfgMemberVariable, SfgConstructor, SfgMethod, SfgVisibility @@ -33,6 +34,30 @@ with SourceFileGenerator(sfg_config) as sfg: visibility=SfgVisibility.PUBLIC )) + cls.add_method(SfgMethod( + "inlineConst", + sfg.seq( + "return -1.0;" + ), + cls, + visibility=SfgVisibility.PUBLIC, + return_type=SrcType("double"), + inline=True, + const=True + )) + + cls.add_method(SfgMethod( + "awesomeMethod", + sfg.seq( + "return 2.0f;" + ), + cls, + visibility=SfgVisibility.PRIVATE, + return_type=SrcType("float"), + inline=False, + const=True + )) + cls.add_member_variable( SfgMemberVariable( "stuff", "std::vector< int >", diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index bfd48c8d52d6922ec475a06c3aee19364dadf9b8..9f0a03fb0557a759a68d064e2a7e0349e6703978 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -60,6 +60,10 @@ class SfgGeneralPrinter: else: return "" + def param_list(self, func: SfgFunction) -> str: + params = sorted(list(func.parameters), key=lambda p: p.name) + return ", ".join(f"{param.dtype} {param.name}" for param in params) + class SfgHeaderPrinter(SfgGeneralPrinter): @@ -108,7 +112,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): def function(self, func: SfgFunction): params = sorted(list(func.parameters), key=lambda p: p.name) param_list = ", ".join(f"{param.dtype} {param.name}" for param in params) - return f"void {func.name} ( {param_list} );" + return f"{func.return_type} {func.name} ( {param_list} );" @visit.case(SfgClass) def sfg_class(self, cls: SfgClass): @@ -150,9 +154,12 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgMethod) def sfg_method(self, method: SfgMethod): - code = f"void {method.name} (" - code += ", ".join(f"{param.dtype} {param.name}" for param in method.parameters) - code += ");" + code = f"{method.return_type} {method.name} ({self.param_list(method)})" + code += "const" if method.const else ")" + if method.inline: + code += " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n" + else: + code += ";" return code @@ -223,20 +230,19 @@ class SfgImplPrinter(SfgGeneralPrinter): @visit.case(SfgFunction) def function(self, func: SfgFunction) -> str: - return self.method_or_func(func, func.name) + code = f"{func.return_type} {func.name} ({self.param_list(func)})" + code += "{\n" + self._ctx.codestyle.indent(func.tree.get_code(self._ctx)) + "}\n" + return code @visit.case(SfgClass) def sfg_class(self, cls: SfgClass) -> str: - return "\n".join(self.visit(m) for m in cls.methods()) + methods = filter(lambda m: not m.inline, cls.methods()) + return "\n".join(self.visit(m) for m in methods) @visit.case(SfgMethod) def sfg_method(self, method: SfgMethod) -> str: - return self.method_or_func(method, f"{method.owning_class.class_name}::{method.name}") - - def method_or_func(self, func: SfgFunction, fully_qualified_name: str) -> str: - params = sorted(list(func.parameters), key=lambda p: p.name) - param_list = ", ".join(f"{param.dtype} {param.name}" for param in params) - code = f"void {fully_qualified_name} ({param_list}) {{\n" - code += self._ctx.codestyle.indent(func.tree.get_code(self._ctx)) - code += "}\n" + const_qual = "const" if method.const else "" + code = f"{method.return_type} {method.owning_class.class_name}::{method.name}" + code += f"({self.param_list(method)}) {const_qual}" + code += " {\n" + self._ctx.codestyle.indent(method.tree.get_code(self._ctx)) + "}\n" return code diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index b921e36d73d2eedc5eee87eac275d273c31c20f6..1bf7242ce6ca8099d04706d058e3c97ac30b3ace 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -13,7 +13,6 @@ from .source_concepts import SrcObject from .exceptions import SfgException if TYPE_CHECKING: - from .context import SfgContext from .tree import SfgCallTreeNode @@ -173,9 +172,10 @@ class SfgKernelHandle: class SfgFunction: - def __init__(self, name: str, tree: SfgCallTreeNode): + def __init__(self, name: str, tree: SfgCallTreeNode, return_type: SrcType = SrcType("void")): self._name = name self._tree = tree + self._return_type = return_type from .visitors.tree_visitors import ExpandingParameterCollector @@ -194,8 +194,9 @@ class SfgFunction: def tree(self): return self._tree - def get_code(self, ctx: SfgContext): - return self._tree.get_code(ctx) + @property + def return_type(self) -> SrcType: + return self._return_type class SfgVisibility(Enum): @@ -258,10 +259,24 @@ class SfgMethod(SfgFunction, SfgClassMember): tree: SfgCallTreeNode, cls: SfgClass, visibility: SfgVisibility = SfgVisibility.PUBLIC, + return_type: SrcType = SrcType("void"), + inline: bool = False, + const: bool = False ): - SfgFunction.__init__(self, name, tree) + SfgFunction.__init__(self, name, tree, return_type=return_type) SfgClassMember.__init__(self, cls, visibility) + self._inline = inline + self._const = const + + @property + def inline(self) -> bool: + return self._inline + + @property + def const(self) -> bool: + return self._const + class SfgConstructor(SfgClassMember): def __init__(