diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 09384892fb5f0975daadbbfd1fa77498c3f3bd80..35da6c5a1603756957b5a2b5752744a4f5f2f2aa 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -65,15 +65,15 @@ class SfgNodeBuilder(ABC): pass -_ExprLike = (str, AugExpr, TypedSymbol) -ExprLike: TypeAlias = str | AugExpr | TypedSymbol +_ExprLike = (str, AugExpr, SfgVar, TypedSymbol) +ExprLike: TypeAlias = str | AugExpr | SfgVar | TypedSymbol """Things that may act as a C++ expression. Expressions need not necesserily have a known data type. """ -_VarLike = (TypedSymbol, AugExpr) -VarLike: TypeAlias = TypedSymbol | AugExpr +_VarLike = (TypedSymbol, SfgVar, AugExpr) +VarLike: TypeAlias = TypedSymbol | SfgVar | AugExpr """Things that may act as a variable. Variables must always define their name *and* data type. @@ -113,7 +113,7 @@ class SfgBasicComposer(SfgIComposer): def define(self, *definitions: str): """Add custom definitions to the generated header file. - + Each string passed to this method will be printed out directly into the generated header file. :Example: @@ -138,7 +138,7 @@ class SfgBasicComposer(SfgIComposer): def namespace(self, namespace: str): """Set the inner code namespace. Throws an exception if a namespace was already set. - + :Example: After adding the following to your generator script: @@ -150,14 +150,15 @@ class SfgBasicComposer(SfgIComposer): .. code-block:: C++ namespace codegen_is_awesome { - /* all generated code */ + /* all generated code */ } """ self._ctx.set_namespace(namespace) def generate(self, generator: CustomGenerator): """Invoke a custom code generator with the underlying context.""" - generator.generate(self._ctx) + from .composer import SfgComposer + generator.generate(SfgComposer(self)) @property def kernels(self) -> SfgKernelNamespace: @@ -254,7 +255,7 @@ class SfgBasicComposer(SfgIComposer): if self._ctx.get_function(name) is not None: raise ValueError(f"Function {name} already exists.") - def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): + def sequencer(*args: SequencerArg): tree = make_sequence(*args) func = SfgFunction(name, tree) self._ctx.add_function(func) @@ -663,6 +664,8 @@ def struct_from_numpy_dtype( def _asvar(var: VarLike) -> SfgVar: match var: + case SfgVar(): + return var case AugExpr(): return var.as_variable() case TypedSymbol(): @@ -683,6 +686,8 @@ def _depends(expr: ExprLike) -> set[SfgVar]: match expr: case None | str(): return set() + case SfgVar(): + return {expr} case TypedSymbol(): return {_asvar(expr)} case AugExpr(): diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 9b8fcfb13e5cfa84349275735cab8ee25f53771f..ed081dc0450e0fc04e432cb07b5d8de45188c84e 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -1,11 +1,8 @@ from __future__ import annotations from typing import Sequence -from pystencils import TypedSymbol from pystencils.types import PsCustomType, UserTypeSpec -from ..lang import AugExpr -from ..ir import SfgCallTreeNode from ..ir.source_components import ( SfgClass, SfgClassMember, @@ -21,11 +18,11 @@ from ..exceptions import SfgException from .mixin import SfgComposerMixIn from .basic_composer import ( - SfgNodeBuilder, make_sequence, _VarLike, VarLike, ExprLike, + SequencerArg, _asvar, ) @@ -79,8 +76,8 @@ class SfgClassComposer(SfgComposerMixIn): def init(self, var: VarLike): """Add an initialization expression to the constructor's initializer list.""" - def init_sequencer(expr: ExprLike): - expr = str(expr) + def init_sequencer(*args: ExprLike): + expr = ", ".join(str(arg) for arg in args) initializer = f"{_asvar(var)}{{ {expr} }}" self._initializers.append(initializer) return self @@ -159,7 +156,7 @@ class SfgClassComposer(SfgComposerMixIn): const: Whether or not the method is const-qualified. """ - def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): + def sequencer(*args: SequencerArg): tree = make_sequence(*args) return SfgMethod( name, @@ -221,7 +218,7 @@ class SfgClassComposer(SfgComposerMixIn): arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str, ) -> SfgClassMember: match arg: - case AugExpr() | TypedSymbol(): + case _ if isinstance(arg, _VarLike): var = _asvar(arg) return SfgMemberVariable(var.name, var.dtype) case str(): diff --git a/src/pystencilssfg/composer/custom.py b/src/pystencilssfg/composer/custom.py index fa53d6af033fa52743d2f1c066e41950f9394a0e..7df364c6cd78c1a56d68283f3c617092938a4dcf 100644 --- a/src/pystencilssfg/composer/custom.py +++ b/src/pystencilssfg/composer/custom.py @@ -1,5 +1,9 @@ +from __future__ import annotations from abc import ABC, abstractmethod -from ..context import SfgContext +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .composer import SfgComposer class CustomGenerator(ABC): @@ -7,4 +11,4 @@ class CustomGenerator(ABC): `SfgComposer.generate`.""" @abstractmethod - def generate(self, ctx: SfgContext) -> None: ... + def generate(self, sfg: SfgComposer) -> None: ... diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 859f9266de18fb952405721892136bde3fda3fd9..c8a72df8ce16123f627f2ec521e44cd99ca62927 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -520,7 +520,7 @@ class SfgClass: self._definitions: list[SfgInClassDefinition] = [] self._constructors: list[SfgConstructor] = [] - self._methods: dict[str, SfgMethod] = dict() + self._methods: list[SfgMethod] = [] self._member_vars: dict[str, SfgMemberVariable] = dict() @property @@ -599,10 +599,10 @@ class SfgClass: ) -> Generator[SfgMethod, None, None]: if visibility is not None: yield from filter( - lambda m: m.visibility == visibility, self._methods.values() + lambda m: m.visibility == visibility, self._methods ) else: - yield from self._methods.values() + yield from self._methods # PRIVATE @@ -624,16 +624,10 @@ class SfgClass: self._definitions.append(definition) def _add_constructor(self, constr: SfgConstructor): - # TODO: Check for signature conflicts? self._constructors.append(constr) def _add_method(self, method: SfgMethod): - if method.name in self._methods: - raise SfgException( - f"Duplicate method name {method.name} in class {self._class_name}" - ) - - self._methods[method.name] = method + self._methods.append(method) def _add_member_variable(self, variable: SfgMemberVariable): if variable.name in self._member_vars: diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index 543d3094131f7dd4d0d8edf1673c2e5a27597065..99661af8ec4ec6306b6aad3a219b88a8303712cd 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -6,10 +6,13 @@ from .expressions import ( SrcVector, ) +from .types import Ref + __all__ = [ "DependentExpression", "AugExpr", "IFieldExtraction", "SrcField", "SrcVector", + "Ref", ] diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py new file mode 100644 index 0000000000000000000000000000000000000000..6f23160075050c6dfa33fd17636c2a3f826f263a --- /dev/null +++ b/src/pystencilssfg/lang/types.py @@ -0,0 +1,26 @@ +from typing import Any +from pystencils.types import PsType + + +class Ref(PsType): + """C++ reference type.""" + + __match_args__ = "base_type" + + def __init__(self, base_type: PsType, const: bool = False): + super().__init__(False) + self._base_type = base_type + + def __args__(self) -> tuple[Any, ...]: + return (self.base_type,) + + @property + def base_type(self) -> PsType: + return self._base_type + + def c_string(self) -> str: + base_str = self.base_type.c_string() + return base_str + "&" + + def __repr__(self) -> str: + return f"Ref({repr(self.base_type)})" diff --git a/tests/generator_scripts/expected/Structural.h b/tests/generator_scripts/expected/Structural.h index 45ffcf60d4c878203fb8f9335e97081c1d4461d8..0eb1e25f0704e43172ef77b6cea7f022fd744c42 100644 --- a/tests/generator_scripts/expected/Structural.h +++ b/tests/generator_scripts/expected/Structural.h @@ -15,4 +15,4 @@ namespace awesome { #define PI 3.1415 using namespace std; -} +} // namespace awesome