Skip to content
Snippets Groups Projects
Commit efb220a6 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Upgrade `sfg.method` sequencer.

parent b0a97749
No related branches found
No related tags found
1 merge request!21Composer API Extensions and How-To Guide
Pipeline #74324 passed
......@@ -610,10 +610,10 @@ def make_sequence(*args: SequencerArg) -> SfgSequence:
return SfgSequence(children)
class SfgFunctionSequencer:
"""Sequencer for constructing free functions.
class SfgFunctionSequencerBase:
"""Common base class for function and method sequencers.
This builder uses call sequencing to specify the function's properties.
This builder uses call sequencing to specify the function or method's properties.
Example:
......@@ -641,12 +641,12 @@ class SfgFunctionSequencer:
# Attributes
self._attributes: list[str] = []
def returns(self, rtype: UserTypeSpec) -> SfgFunctionSequencer:
def returns(self, rtype: UserTypeSpec):
"""Set the return type of the function"""
self._return_type = create_type(rtype)
return self
def params(self, *args: VarLike) -> SfgFunctionSequencer:
def params(self, *args: VarLike):
"""Specify the parameters for this function.
Use this to manually specify the function's parameter list.
......@@ -657,21 +657,25 @@ class SfgFunctionSequencer:
self._params = [asvar(v) for v in args]
return self
def inline(self) -> SfgFunctionSequencer:
def inline(self):
"""Mark this function as ``inline``."""
self._inline = True
return self
def constexpr(self) -> SfgFunctionSequencer:
def constexpr(self):
"""Mark this function as ``constexpr``."""
self._constexpr = True
return self
def attr(self, *attrs: str) -> SfgFunctionSequencer:
def attr(self, *attrs: str):
"""Add attributes to this function"""
self._attributes += attrs
return self
class SfgFunctionSequencer(SfgFunctionSequencerBase):
"""Sequencer for constructing functions."""
def __call__(self, *args: SequencerArg) -> None:
"""Populate the function body"""
tree = make_sequence(*args)
......
......@@ -3,9 +3,9 @@ from typing import Sequence
from itertools import takewhile, dropwhile
import numpy as np
from pystencils.types import PsCustomType, UserTypeSpec, create_type
from pystencils.types import create_type
from ..context import SfgContext
from ..context import SfgContext, SfgCursor
from ..lang import (
VarLike,
ExprLike,
......@@ -32,9 +32,69 @@ from .mixin import SfgComposerMixIn
from .basic_composer import (
make_sequence,
SequencerArg,
SfgFunctionSequencerBase,
)
class SfgMethodSequencer(SfgFunctionSequencerBase):
def __init__(self, cursor: SfgCursor, name: str) -> None:
super().__init__(cursor, name)
self._const: bool = False
self._static: bool = False
self._virtual: bool = False
self._override: bool = False
self._tree: SfgCallTreeNode
def const(self):
"""Mark this method as ``const``."""
self._const = True
return self
def static(self):
"""Mark this method as ``static``."""
self._static = True
return self
def virtual(self):
"""Mark this method as ``virtual``."""
self._virtual = True
return self
def override(self):
"""Mark this method as ``override``."""
self._override = True
return self
def __call__(self, *args: SequencerArg):
self._tree = make_sequence(*args)
return self
def _resolve(self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock):
method = SfgMethod(
self._name,
cls,
self._tree,
return_type=self._return_type,
inline=self._inline,
const=self._const,
static=self._static,
constexpr=self._constexpr,
virtual=self._virtual,
override=self._override,
attributes=self._attributes,
required_params=self._params,
)
cls.add_member(method, vis_block.visibility)
if self._inline:
vis_block.elements.append(SfgEntityDef(method))
else:
vis_block.elements.append(SfgEntityDecl(method))
ctx._cursor.write_impl(SfgEntityDef(method))
class SfgClassComposer(SfgComposerMixIn):
"""Composer for classes and structs.
......@@ -53,7 +113,7 @@ class SfgClassComposer(SfgComposerMixIn):
def __init__(self, visibility: SfgVisibility):
self._visibility = visibility
self._args: tuple[
SfgClassComposer.MethodSequencer
SfgMethodSequencer
| SfgClassComposer.ConstructorBuilder
| VarLike
| str,
......@@ -63,10 +123,7 @@ class SfgClassComposer(SfgComposerMixIn):
def __call__(
self,
*args: (
SfgClassComposer.MethodSequencer
| SfgClassComposer.ConstructorBuilder
| VarLike
| str
SfgMethodSequencer | SfgClassComposer.ConstructorBuilder | VarLike | str
),
):
self._args = args
......@@ -76,10 +133,7 @@ class SfgClassComposer(SfgComposerMixIn):
vis_block = SfgVisibilityBlock(self._visibility)
for arg in self._args:
match arg:
case (
SfgClassComposer.MethodSequencer()
| SfgClassComposer.ConstructorBuilder()
):
case SfgMethodSequencer() | SfgClassComposer.ConstructorBuilder():
arg._resolve(ctx, cls, vis_block)
case str():
vis_block.elements.append(arg)
......@@ -90,43 +144,6 @@ class SfgClassComposer(SfgComposerMixIn):
vis_block.elements.append(SfgEntityDef(member_var))
return vis_block
class MethodSequencer:
def __init__(
self,
name: str,
returns: UserTypeSpec = PsCustomType("void"),
inline: bool = False,
const: bool = False,
) -> None:
self._name = name
self._returns = create_type(returns)
self._inline = inline
self._const = const
self._tree: SfgCallTreeNode
def __call__(self, *args: SequencerArg):
self._tree = make_sequence(*args)
return self
def _resolve(
self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock
):
method = SfgMethod(
self._name,
cls,
self._tree,
return_type=self._returns,
inline=self._inline,
const=self._const,
)
cls.add_member(method, vis_block.visibility)
if self._inline:
vis_block.elements.append(SfgEntityDef(method))
else:
vis_block.elements.append(SfgEntityDecl(method))
ctx._cursor.write_impl(SfgEntityDef(method))
class ConstructorBuilder:
"""Composer syntax for constructor building.
......@@ -197,9 +214,7 @@ class SfgClassComposer(SfgComposerMixIn):
"""
return self._class(class_name, SfgClassKeyword.STRUCT, bases)
def numpy_struct(
self, name: str, dtype: np.dtype, add_constructor: bool = True
):
def numpy_struct(self, name: str, dtype: np.dtype, add_constructor: bool = True):
"""Add a numpy structured data type as a C++ struct
Returns:
......@@ -230,13 +245,7 @@ class SfgClassComposer(SfgComposerMixIn):
"""
return SfgClassComposer.ConstructorBuilder(*params)
def method(
self,
name: str,
returns: UserTypeSpec = PsCustomType("void"),
inline: bool = False,
const: bool = False,
):
def method(self, name: str):
"""In a class or struct body or visibility block, add a method.
The usage is similar to :any:`SfgBasicComposer.function`.
......@@ -247,7 +256,7 @@ class SfgClassComposer(SfgComposerMixIn):
const: Whether or not the method is const-qualified.
"""
return SfgClassComposer.MethodSequencer(name, returns, inline, const)
return SfgMethodSequencer(self._cursor, name)
# INTERNALS
......@@ -270,7 +279,7 @@ class SfgClassComposer(SfgComposerMixIn):
def sequencer(
*args: (
SfgClassComposer.VisibilityBlockSequencer
| SfgClassComposer.MethodSequencer
| SfgMethodSequencer
| SfgClassComposer.ConstructorBuilder
| VarLike
| str
......
......@@ -183,11 +183,14 @@ class SfgFilePrinter:
if func.attributes:
code += "[[" + ", ".join(func.attributes) + "]]"
if func.inline:
if func.inline and not inclass:
code += "inline "
if isinstance(func, SfgMethod) and func.static:
code += "static "
if isinstance(func, SfgMethod) and inclass:
if func.static:
code += "static "
if func.virtual:
code += "virtual "
if func.constexpr:
code += "constexpr "
......@@ -200,7 +203,10 @@ class SfgFilePrinter:
code += f"{func.owning_class.name}::"
code += f"{func.name}({params_str})"
if isinstance(func, SfgMethod) and func.const:
code += " const"
if isinstance(func, SfgMethod):
if func.const:
code += " const"
if func.override and inclass:
code += " override"
return code
......@@ -204,7 +204,7 @@ class SfgKernelNamespace(SfgNamespace):
self._kernels[kernel.name] = kernel
@dataclass(frozen=True)
@dataclass(frozen=True, match_args=False)
class CommonFunctionProperties:
tree: SfgCallTreeNode
parameters: tuple[SfgVar, ...]
......@@ -213,6 +213,26 @@ class CommonFunctionProperties:
constexpr: bool
attributes: Sequence[str]
@staticmethod
def collect_params(tree: SfgCallTreeNode, required_params: Sequence[SfgVar] | None):
from .postprocessing import CallTreePostProcessing
param_collector = CallTreePostProcessing()
params_set = param_collector(tree).function_params
if required_params is not None:
if not (params_set <= set(required_params)):
extras = params_set - set(required_params)
raise SfgException(
"Extraenous function parameters: "
f"Found free variables {extras} that were not listed in manually specified function parameters."
)
parameters = tuple(required_params)
else:
parameters = tuple(sorted(params_set, key=lambda p: p.name))
return parameters
class SfgFunction(SfgCodeEntity, CommonFunctionProperties):
"""A free function."""
......@@ -232,21 +252,7 @@ class SfgFunction(SfgCodeEntity, CommonFunctionProperties):
):
super().__init__(name, namespace)
from .postprocessing import CallTreePostProcessing
param_collector = CallTreePostProcessing()
params_set = param_collector(tree).function_params
if required_params is not None:
if not (params_set <= set(required_params)):
extras = params_set - set(required_params)
raise SfgException(
"Extraenous function parameters: "
f"Found free variables {extras} that were not listed in manually specified function parameters."
)
parameters = tuple(required_params)
else:
parameters = tuple(sorted(params_set, key=lambda p: p.name))
parameters = self.collect_params(tree, required_params)
CommonFunctionProperties.__init__(
self,
......@@ -349,21 +355,20 @@ class SfgMethod(SfgClassMember, CommonFunctionProperties):
const: bool = False,
static: bool = False,
constexpr: bool = False,
virtual: bool = False,
override: bool = False,
attributes: Sequence[str] = (),
required_params: Sequence[SfgVar] | None = None,
):
super().__init__(cls)
self._name = name
from .postprocessing import CallTreePostProcessing
param_collector = CallTreePostProcessing()
parameters = tuple(
sorted(param_collector(tree).function_params, key=lambda p: p.name)
)
self._static = static
self._const = const
self._virtual = virtual
self._override = override
parameters = self.collect_params(tree, required_params)
CommonFunctionProperties.__init__(
self,
......@@ -387,6 +392,14 @@ class SfgMethod(SfgClassMember, CommonFunctionProperties):
def const(self) -> bool:
return self._const
@property
def virtual(self) -> bool:
return self._virtual
@property
def override(self) -> bool:
return self._override
class SfgConstructor(SfgClassMember):
"""Constructor of a class"""
......
......@@ -48,6 +48,10 @@ SimpleClasses:
output-mode: header-only
ComposerFeatures:
expect-code:
hpp:
- regex: >-
\[\[nodiscard\]\]\s*static\s*double\s*geometric\(\s*double\s*q,\s*uint64_t\s*k\)
Conditionals:
expect-code:
......
#include "ComposerFeatures.hpp"
/* factorial is constexpr -> evaluate at compile-time */
#include <cmath>
#undef NDEBUG
#include <cassert>
/* Evaluate constexpr functions at compile-time */
static_assert( factorial(0) == 1 );
static_assert( factorial(1) == 1 );
static_assert( factorial(2) == 2 );
......@@ -8,6 +13,23 @@ static_assert( factorial(3) == 6 );
static_assert( factorial(4) == 24 );
static_assert( factorial(5) == 120 );
static_assert( ConstexprMath::abs(ConstexprMath::geometric(0.5, 0) - 1.0) < 1e-10 );
static_assert( ConstexprMath::abs(ConstexprMath::geometric(0.5, 1) - 1.5) < 1e-10 );
static_assert( ConstexprMath::abs(ConstexprMath::geometric(0.5, 2) - 1.75) < 1e-10 );
static_assert( ConstexprMath::abs(ConstexprMath::geometric(0.5, 3) - 1.875) < 1e-10 );
int main(void) {
return 0;
assert( std::fabs(Series::geometric(0.5, 0) - 1.0) < 1e-10 );
assert( std::fabs(Series::geometric(0.5, 1) - 1.5) < 1e-10 );
assert( std::fabs(Series::geometric(0.5, 2) - 1.75) < 1e-10 );
assert( std::fabs(Series::geometric(0.5, 3) - 1.875) < 1e-10 );
inheritance_test::Parent p;
assert( p.compute() == 24 );
inheritance_test::Child c;
assert( c.compute() == 31 );
auto & cp = dynamic_cast< inheritance_test::Parent & >(c);
assert( cp.compute() == 31 );
}
......@@ -3,11 +3,67 @@ from pystencilssfg import SourceFileGenerator
with SourceFileGenerator() as sfg:
# Inline constexpr function with explicit parameter list
sfg.function("factorial").params(sfg.var("n", "uint64")).returns("uint64").inline().constexpr()(
sfg.branch("n == 0")(
"return 1;"
)(
"return n * factorial(n - 1);"
sfg.function("factorial").params(sfg.var("n", "uint64")).returns(
"uint64"
).inline().constexpr()(
sfg.branch("n == 0")("return 1;")("return n * factorial(n - 1);")
)
q = sfg.var("q", "double")
k = sfg.var("k", "uint64_t")
x = sfg.var("x", "double")
sfg.include("<cmath>")
sfg.struct("Series")(
sfg.method("geometric")
.static()
.attr("nodiscard")
.params(q, k)
.returns("double")(
sfg.branch("k == 0")(
"return 1.0;"
)(
"return Series::geometric(q, k - 1) + std::pow(q, k);"
)
)
)
sfg.struct("ConstexprMath")(
sfg.method("abs").static().constexpr().inline()
.params(x)
.returns("double")
(
"if (x >= 0.0) return x; else return -x;"
),
sfg.method("geometric")
.static()
.constexpr()
.inline()
.params(q, k)
.returns("double")(
sfg.branch("k == 0")(
"return 1.0;"
)(
"return 1 + q * ConstexprMath::geometric(q, k - 1);"
)
)
)
with sfg.namespace("inheritance_test"):
sfg.klass("Parent")(
sfg.public(
sfg.method("compute").returns("int").virtual().const()(
"return 24;"
)
)
)
sfg.klass("Child", bases=["public Parent"])(
sfg.public(
sfg.method("compute").returns("int").override().const()(
"return 31;"
)
)
)
......@@ -12,7 +12,7 @@ with SourceFileGenerator() as sfg:
sfg.klass("Point")(
sfg.public(
sfg.constructor(x, y, z).init(x_)(x).init(y_)(y).init(z_)(z),
sfg.method("getX", returns="const int64_t", const=True, inline=True)(
sfg.method("getX").returns("const int64_t").const().inline()(
"return this->x_;"
),
),
......@@ -22,7 +22,7 @@ with SourceFileGenerator() as sfg:
sfg.klass("SpecialPoint", bases=["public Point"])(
sfg.public(
"using Point::Point;",
sfg.method("getY", returns="const int64_t", const=True, inline=True)(
sfg.method("getY").returns("const int64_t").const().inline()(
"return this->y_;"
),
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment