Skip to content
Snippets Groups Projects
Commit 5c595075 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

More frontend updates

 - Add `Ref` type
 - Allow multi-arg `init` in constructor builder
 - Change `CustomGenerator` to take a composer instead of a context.
 - Allow a class to have multiple methods with the same name.
parent 7a4ff746
No related merge requests found
......@@ -65,15 +65,15 @@ class SfgNodeBuilder(ABC):
pass
_ExprLike = (str, AugExpr, TypedSymbol)
ExprLike: TypeAlias = str | AugExpr | TypedSymbol
_ExprLike = (str, AugExpr, SfgVar, TypedSymbol)
ExprLike: TypeAlias = str | AugExpr | SfgVar | TypedSymbol
"""Things that may act as a C++ expression.
Expressions need not necesserily have a known data type.
"""
_VarLike = (TypedSymbol, AugExpr)
VarLike: TypeAlias = TypedSymbol | AugExpr
_VarLike = (TypedSymbol, SfgVar, AugExpr)
VarLike: TypeAlias = TypedSymbol | SfgVar | AugExpr
"""Things that may act as a variable.
Variables must always define their name *and* data type.
......@@ -113,7 +113,7 @@ class SfgBasicComposer(SfgIComposer):
def define(self, *definitions: str):
"""Add custom definitions to the generated header file.
Each string passed to this method will be printed out directly into the generated header file.
:Example:
......@@ -138,7 +138,7 @@ class SfgBasicComposer(SfgIComposer):
def namespace(self, namespace: str):
"""Set the inner code namespace. Throws an exception if a namespace was already set.
:Example:
After adding the following to your generator script:
......@@ -150,14 +150,15 @@ class SfgBasicComposer(SfgIComposer):
.. code-block:: C++
namespace codegen_is_awesome {
/* all generated code */
/* all generated code */
}
"""
self._ctx.set_namespace(namespace)
def generate(self, generator: CustomGenerator):
"""Invoke a custom code generator with the underlying context."""
generator.generate(self._ctx)
from .composer import SfgComposer
generator.generate(SfgComposer(self))
@property
def kernels(self) -> SfgKernelNamespace:
......@@ -254,7 +255,7 @@ class SfgBasicComposer(SfgIComposer):
if self._ctx.get_function(name) is not None:
raise ValueError(f"Function {name} already exists.")
def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder):
def sequencer(*args: SequencerArg):
tree = make_sequence(*args)
func = SfgFunction(name, tree)
self._ctx.add_function(func)
......@@ -663,6 +664,8 @@ def struct_from_numpy_dtype(
def _asvar(var: VarLike) -> SfgVar:
match var:
case SfgVar():
return var
case AugExpr():
return var.as_variable()
case TypedSymbol():
......@@ -683,6 +686,8 @@ def _depends(expr: ExprLike) -> set[SfgVar]:
match expr:
case None | str():
return set()
case SfgVar():
return {expr}
case TypedSymbol():
return {_asvar(expr)}
case AugExpr():
......
from __future__ import annotations
from typing import Sequence
from pystencils import TypedSymbol
from pystencils.types import PsCustomType, UserTypeSpec
from ..lang import AugExpr
from ..ir import SfgCallTreeNode
from ..ir.source_components import (
SfgClass,
SfgClassMember,
......@@ -21,11 +18,11 @@ from ..exceptions import SfgException
from .mixin import SfgComposerMixIn
from .basic_composer import (
SfgNodeBuilder,
make_sequence,
_VarLike,
VarLike,
ExprLike,
SequencerArg,
_asvar,
)
......@@ -79,8 +76,8 @@ class SfgClassComposer(SfgComposerMixIn):
def init(self, var: VarLike):
"""Add an initialization expression to the constructor's initializer list."""
def init_sequencer(expr: ExprLike):
expr = str(expr)
def init_sequencer(*args: ExprLike):
expr = ", ".join(str(arg) for arg in args)
initializer = f"{_asvar(var)}{{ {expr} }}"
self._initializers.append(initializer)
return self
......@@ -159,7 +156,7 @@ class SfgClassComposer(SfgComposerMixIn):
const: Whether or not the method is const-qualified.
"""
def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder):
def sequencer(*args: SequencerArg):
tree = make_sequence(*args)
return SfgMethod(
name,
......@@ -221,7 +218,7 @@ class SfgClassComposer(SfgComposerMixIn):
arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str,
) -> SfgClassMember:
match arg:
case AugExpr() | TypedSymbol():
case _ if isinstance(arg, _VarLike):
var = _asvar(arg)
return SfgMemberVariable(var.name, var.dtype)
case str():
......
from __future__ import annotations
from abc import ABC, abstractmethod
from ..context import SfgContext
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .composer import SfgComposer
class CustomGenerator(ABC):
......@@ -7,4 +11,4 @@ class CustomGenerator(ABC):
`SfgComposer.generate`."""
@abstractmethod
def generate(self, ctx: SfgContext) -> None: ...
def generate(self, sfg: SfgComposer) -> None: ...
......@@ -520,7 +520,7 @@ class SfgClass:
self._definitions: list[SfgInClassDefinition] = []
self._constructors: list[SfgConstructor] = []
self._methods: dict[str, SfgMethod] = dict()
self._methods: list[SfgMethod] = []
self._member_vars: dict[str, SfgMemberVariable] = dict()
@property
......@@ -599,10 +599,10 @@ class SfgClass:
) -> Generator[SfgMethod, None, None]:
if visibility is not None:
yield from filter(
lambda m: m.visibility == visibility, self._methods.values()
lambda m: m.visibility == visibility, self._methods
)
else:
yield from self._methods.values()
yield from self._methods
# PRIVATE
......@@ -624,16 +624,10 @@ class SfgClass:
self._definitions.append(definition)
def _add_constructor(self, constr: SfgConstructor):
# TODO: Check for signature conflicts?
self._constructors.append(constr)
def _add_method(self, method: SfgMethod):
if method.name in self._methods:
raise SfgException(
f"Duplicate method name {method.name} in class {self._class_name}"
)
self._methods[method.name] = method
self._methods.append(method)
def _add_member_variable(self, variable: SfgMemberVariable):
if variable.name in self._member_vars:
......
......@@ -6,10 +6,13 @@ from .expressions import (
SrcVector,
)
from .types import Ref
__all__ = [
"DependentExpression",
"AugExpr",
"IFieldExtraction",
"SrcField",
"SrcVector",
"Ref",
]
from typing import Any
from pystencils.types import PsType
class Ref(PsType):
"""C++ reference type."""
__match_args__ = "base_type"
def __init__(self, base_type: PsType, const: bool = False):
super().__init__(False)
self._base_type = base_type
def __args__(self) -> tuple[Any, ...]:
return (self.base_type,)
@property
def base_type(self) -> PsType:
return self._base_type
def c_string(self) -> str:
base_str = self.base_type.c_string()
return base_str + "&"
def __repr__(self) -> str:
return f"Ref({repr(self.base_type)})"
......@@ -15,4 +15,4 @@ namespace awesome {
#define PI 3.1415
using namespace std;
}
} // namespace awesome
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment