Skip to content
Snippets Groups Projects
Commit 5c595075 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

More frontend updates

 - Add `Ref` type
 - Allow multi-arg `init` in constructor builder
 - Change `CustomGenerator` to take a composer instead of a context.
 - Allow a class to have multiple methods with the same name.
parent 7a4ff746
No related branches found
No related tags found
No related merge requests found
Pipeline #69659 failed
...@@ -65,15 +65,15 @@ class SfgNodeBuilder(ABC): ...@@ -65,15 +65,15 @@ class SfgNodeBuilder(ABC):
pass pass
_ExprLike = (str, AugExpr, TypedSymbol) _ExprLike = (str, AugExpr, SfgVar, TypedSymbol)
ExprLike: TypeAlias = str | AugExpr | TypedSymbol ExprLike: TypeAlias = str | AugExpr | SfgVar | TypedSymbol
"""Things that may act as a C++ expression. """Things that may act as a C++ expression.
Expressions need not necesserily have a known data type. Expressions need not necesserily have a known data type.
""" """
_VarLike = (TypedSymbol, AugExpr) _VarLike = (TypedSymbol, SfgVar, AugExpr)
VarLike: TypeAlias = TypedSymbol | AugExpr VarLike: TypeAlias = TypedSymbol | SfgVar | AugExpr
"""Things that may act as a variable. """Things that may act as a variable.
Variables must always define their name *and* data type. Variables must always define their name *and* data type.
...@@ -113,7 +113,7 @@ class SfgBasicComposer(SfgIComposer): ...@@ -113,7 +113,7 @@ class SfgBasicComposer(SfgIComposer):
def define(self, *definitions: str): def define(self, *definitions: str):
"""Add custom definitions to the generated header file. """Add custom definitions to the generated header file.
Each string passed to this method will be printed out directly into the generated header file. Each string passed to this method will be printed out directly into the generated header file.
:Example: :Example:
...@@ -138,7 +138,7 @@ class SfgBasicComposer(SfgIComposer): ...@@ -138,7 +138,7 @@ class SfgBasicComposer(SfgIComposer):
def namespace(self, namespace: str): def namespace(self, namespace: str):
"""Set the inner code namespace. Throws an exception if a namespace was already set. """Set the inner code namespace. Throws an exception if a namespace was already set.
:Example: :Example:
After adding the following to your generator script: After adding the following to your generator script:
...@@ -150,14 +150,15 @@ class SfgBasicComposer(SfgIComposer): ...@@ -150,14 +150,15 @@ class SfgBasicComposer(SfgIComposer):
.. code-block:: C++ .. code-block:: C++
namespace codegen_is_awesome { namespace codegen_is_awesome {
/* all generated code */ /* all generated code */
} }
""" """
self._ctx.set_namespace(namespace) self._ctx.set_namespace(namespace)
def generate(self, generator: CustomGenerator): def generate(self, generator: CustomGenerator):
"""Invoke a custom code generator with the underlying context.""" """Invoke a custom code generator with the underlying context."""
generator.generate(self._ctx) from .composer import SfgComposer
generator.generate(SfgComposer(self))
@property @property
def kernels(self) -> SfgKernelNamespace: def kernels(self) -> SfgKernelNamespace:
...@@ -254,7 +255,7 @@ class SfgBasicComposer(SfgIComposer): ...@@ -254,7 +255,7 @@ class SfgBasicComposer(SfgIComposer):
if self._ctx.get_function(name) is not None: if self._ctx.get_function(name) is not None:
raise ValueError(f"Function {name} already exists.") raise ValueError(f"Function {name} already exists.")
def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): def sequencer(*args: SequencerArg):
tree = make_sequence(*args) tree = make_sequence(*args)
func = SfgFunction(name, tree) func = SfgFunction(name, tree)
self._ctx.add_function(func) self._ctx.add_function(func)
...@@ -663,6 +664,8 @@ def struct_from_numpy_dtype( ...@@ -663,6 +664,8 @@ def struct_from_numpy_dtype(
def _asvar(var: VarLike) -> SfgVar: def _asvar(var: VarLike) -> SfgVar:
match var: match var:
case SfgVar():
return var
case AugExpr(): case AugExpr():
return var.as_variable() return var.as_variable()
case TypedSymbol(): case TypedSymbol():
...@@ -683,6 +686,8 @@ def _depends(expr: ExprLike) -> set[SfgVar]: ...@@ -683,6 +686,8 @@ def _depends(expr: ExprLike) -> set[SfgVar]:
match expr: match expr:
case None | str(): case None | str():
return set() return set()
case SfgVar():
return {expr}
case TypedSymbol(): case TypedSymbol():
return {_asvar(expr)} return {_asvar(expr)}
case AugExpr(): case AugExpr():
......
from __future__ import annotations from __future__ import annotations
from typing import Sequence from typing import Sequence
from pystencils import TypedSymbol
from pystencils.types import PsCustomType, UserTypeSpec from pystencils.types import PsCustomType, UserTypeSpec
from ..lang import AugExpr
from ..ir import SfgCallTreeNode
from ..ir.source_components import ( from ..ir.source_components import (
SfgClass, SfgClass,
SfgClassMember, SfgClassMember,
...@@ -21,11 +18,11 @@ from ..exceptions import SfgException ...@@ -21,11 +18,11 @@ from ..exceptions import SfgException
from .mixin import SfgComposerMixIn from .mixin import SfgComposerMixIn
from .basic_composer import ( from .basic_composer import (
SfgNodeBuilder,
make_sequence, make_sequence,
_VarLike, _VarLike,
VarLike, VarLike,
ExprLike, ExprLike,
SequencerArg,
_asvar, _asvar,
) )
...@@ -79,8 +76,8 @@ class SfgClassComposer(SfgComposerMixIn): ...@@ -79,8 +76,8 @@ class SfgClassComposer(SfgComposerMixIn):
def init(self, var: VarLike): def init(self, var: VarLike):
"""Add an initialization expression to the constructor's initializer list.""" """Add an initialization expression to the constructor's initializer list."""
def init_sequencer(expr: ExprLike): def init_sequencer(*args: ExprLike):
expr = str(expr) expr = ", ".join(str(arg) for arg in args)
initializer = f"{_asvar(var)}{{ {expr} }}" initializer = f"{_asvar(var)}{{ {expr} }}"
self._initializers.append(initializer) self._initializers.append(initializer)
return self return self
...@@ -159,7 +156,7 @@ class SfgClassComposer(SfgComposerMixIn): ...@@ -159,7 +156,7 @@ class SfgClassComposer(SfgComposerMixIn):
const: Whether or not the method is const-qualified. const: Whether or not the method is const-qualified.
""" """
def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): def sequencer(*args: SequencerArg):
tree = make_sequence(*args) tree = make_sequence(*args)
return SfgMethod( return SfgMethod(
name, name,
...@@ -221,7 +218,7 @@ class SfgClassComposer(SfgComposerMixIn): ...@@ -221,7 +218,7 @@ class SfgClassComposer(SfgComposerMixIn):
arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str, arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str,
) -> SfgClassMember: ) -> SfgClassMember:
match arg: match arg:
case AugExpr() | TypedSymbol(): case _ if isinstance(arg, _VarLike):
var = _asvar(arg) var = _asvar(arg)
return SfgMemberVariable(var.name, var.dtype) return SfgMemberVariable(var.name, var.dtype)
case str(): case str():
......
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from ..context import SfgContext from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .composer import SfgComposer
class CustomGenerator(ABC): class CustomGenerator(ABC):
...@@ -7,4 +11,4 @@ class CustomGenerator(ABC): ...@@ -7,4 +11,4 @@ class CustomGenerator(ABC):
`SfgComposer.generate`.""" `SfgComposer.generate`."""
@abstractmethod @abstractmethod
def generate(self, ctx: SfgContext) -> None: ... def generate(self, sfg: SfgComposer) -> None: ...
...@@ -520,7 +520,7 @@ class SfgClass: ...@@ -520,7 +520,7 @@ class SfgClass:
self._definitions: list[SfgInClassDefinition] = [] self._definitions: list[SfgInClassDefinition] = []
self._constructors: list[SfgConstructor] = [] self._constructors: list[SfgConstructor] = []
self._methods: dict[str, SfgMethod] = dict() self._methods: list[SfgMethod] = []
self._member_vars: dict[str, SfgMemberVariable] = dict() self._member_vars: dict[str, SfgMemberVariable] = dict()
@property @property
...@@ -599,10 +599,10 @@ class SfgClass: ...@@ -599,10 +599,10 @@ class SfgClass:
) -> Generator[SfgMethod, None, None]: ) -> Generator[SfgMethod, None, None]:
if visibility is not None: if visibility is not None:
yield from filter( yield from filter(
lambda m: m.visibility == visibility, self._methods.values() lambda m: m.visibility == visibility, self._methods
) )
else: else:
yield from self._methods.values() yield from self._methods
# PRIVATE # PRIVATE
...@@ -624,16 +624,10 @@ class SfgClass: ...@@ -624,16 +624,10 @@ class SfgClass:
self._definitions.append(definition) self._definitions.append(definition)
def _add_constructor(self, constr: SfgConstructor): def _add_constructor(self, constr: SfgConstructor):
# TODO: Check for signature conflicts?
self._constructors.append(constr) self._constructors.append(constr)
def _add_method(self, method: SfgMethod): def _add_method(self, method: SfgMethod):
if method.name in self._methods: self._methods.append(method)
raise SfgException(
f"Duplicate method name {method.name} in class {self._class_name}"
)
self._methods[method.name] = method
def _add_member_variable(self, variable: SfgMemberVariable): def _add_member_variable(self, variable: SfgMemberVariable):
if variable.name in self._member_vars: if variable.name in self._member_vars:
......
...@@ -6,10 +6,13 @@ from .expressions import ( ...@@ -6,10 +6,13 @@ from .expressions import (
SrcVector, SrcVector,
) )
from .types import Ref
__all__ = [ __all__ = [
"DependentExpression", "DependentExpression",
"AugExpr", "AugExpr",
"IFieldExtraction", "IFieldExtraction",
"SrcField", "SrcField",
"SrcVector", "SrcVector",
"Ref",
] ]
from typing import Any
from pystencils.types import PsType
class Ref(PsType):
"""C++ reference type."""
__match_args__ = "base_type"
def __init__(self, base_type: PsType, const: bool = False):
super().__init__(False)
self._base_type = base_type
def __args__(self) -> tuple[Any, ...]:
return (self.base_type,)
@property
def base_type(self) -> PsType:
return self._base_type
def c_string(self) -> str:
base_str = self.base_type.c_string()
return base_str + "&"
def __repr__(self) -> str:
return f"Ref({repr(self.base_type)})"
...@@ -15,4 +15,4 @@ namespace awesome { ...@@ -15,4 +15,4 @@ namespace awesome {
#define PI 3.1415 #define PI 3.1415
using namespace std; using namespace std;
} } // namespace awesome
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment