diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index b791f1cd225751236e18d7d9ed354e10f61bf356..4bdf8b30cc8cf7043d862b72af8d5f2b1f27c227 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -44,7 +44,6 @@ from ..ir.source_components import ( SfgConstructor, SfgMemberVariable, SfgClassKeyword, - SfgVar, ) from ..lang import ( VarLike, @@ -53,10 +52,11 @@ from ..lang import ( _ExprLike, asvar, depends, - IFieldExtraction, - SrcVector, + SfgVar, AugExpr, SrcField, + IFieldExtraction, + SrcVector, ) from ..exceptions import SfgException @@ -71,6 +71,8 @@ class SfgIComposer(ABC): class SfgNodeBuilder(ABC): + """Base class for node builders used by the composer""" + @abstractmethod def resolve(self) -> SfgCallTreeNode: pass @@ -211,7 +213,7 @@ class SfgBasicComposer(SfgIComposer): if self._ctx.get_class(name) is not None: raise SfgException(f"Class with name {name} already exists.") - cls = struct_from_numpy_dtype(name, dtype, add_constructor=add_constructor) + cls = _struct_from_numpy_dtype(name, dtype, add_constructor=add_constructor) self._ctx.add_class(cls) return cls @@ -344,7 +346,7 @@ class SfgBasicComposer(SfgIComposer): varnames = names.split(",") return tuple(self.var(n.strip(), dtype) for n in varnames) - def init(self, lhs: VarLike) -> SfgInplaceInitBuilder: + def init(self, lhs: VarLike): """Create a C++ in-place initialization. Usage: @@ -360,7 +362,18 @@ class SfgBasicComposer(SfgIComposer): SomeClass obj { arg1, arg2, arg3 }; """ - return SfgInplaceInitBuilder(asvar(lhs)) + lhs_var = asvar(lhs) + + def parse_args(*args: ExprLike): + args_str = ", ".join(str(arg) for arg in args) + deps = reduce(set.union, (depends(arg) for arg in args), set()) + return SfgStatements( + f"{lhs_var.dtype} {lhs_var.name} {{ {args_str} }};", + (lhs_var,), + deps, + ) + + return parse_args def expr(self, fmt: str, *deps, **kwdeps) -> AugExpr: """Create an expression while keeping track of variables it depends on. @@ -429,9 +442,7 @@ class SfgBasicComposer(SfgIComposer): def set_param(self, param: VarLike | sp.Symbol, expr: ExprLike): deps = depends(expr) - var: SfgVar | sp.Symbol = ( - asvar(param) if isinstance(param, _VarLike) else param - ) + var: SfgVar | sp.Symbol = asvar(param) if isinstance(param, _VarLike) else param return SfgDeferredParamSetter(var, deps, str(expr)) def map_param( @@ -447,9 +458,7 @@ class SfgBasicComposer(SfgIComposer): lhs_var: SfgVar | sp.Symbol = ( asvar(param) if isinstance(param, _VarLike) else param ) - return SfgDeferredParamMapping( - lhs_var, set(asvar(v) for v in depends), mapping - ) + return SfgDeferredParamMapping(lhs_var, set(asvar(v) for v in depends), mapping) def map_vector(self, lhs_components: Sequence[VarLike | sp.Symbol], rhs: SrcVector): """Extracts scalar numerical values from a vector data type. @@ -532,33 +541,9 @@ def make_sequence(*args: SequencerArg) -> SfgSequence: return SfgSequence(children) -class SfgInplaceInitBuilder(SfgNodeBuilder): - def __init__(self, lhs: SfgVar) -> None: - self._lhs: SfgVar = lhs - self.depends: set[SfgVar] = set() - self._rhs: str | None = None - - def __call__( - self, - *rhs: ExprLike, - ) -> SfgInplaceInitBuilder: - if self._rhs is not None: - raise SfgException("Assignment builder used multiple times.") - - self._rhs = ", ".join(str(expr) for expr in rhs) - self.depends = reduce(set.union, (depends(obj) for obj in rhs), set()) - return self - - def resolve(self) -> SfgCallTreeNode: - assert self._rhs is not None - return SfgStatements( - f"{self._lhs.dtype} {self._lhs.name} {{ {self._rhs} }};", - [self._lhs], - self.depends, - ) - - class SfgBranchBuilder(SfgNodeBuilder): + """Multi-call builder for C++ ``if/else`` statements.""" + def __init__(self) -> None: self._phase = 0 @@ -595,6 +580,7 @@ class SfgBranchBuilder(SfgNodeBuilder): class SfgSwitchBuilder(SfgNodeBuilder): + """Builder for C++ switches.""" def __init__(self, switch_arg: ExprLike): self._switch_arg = switch_arg self._cases: dict[str, SfgSequence] = dict() @@ -629,7 +615,7 @@ class SfgSwitchBuilder(SfgNodeBuilder): return SfgSwitch(make_statements(self._switch_arg), self._cases, self._default) -def struct_from_numpy_dtype( +def _struct_from_numpy_dtype( struct_name: str, dtype: np.dtype, add_constructor: bool = True ): cls = SfgClass(struct_name, class_keyword=SfgClassKeyword.STRUCT) diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index cc3f83fa6a049febd76d1e62dcdda39d4a502575..af59f4f82b1451928b3e3bbdb6835d4ce92f33c5 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -16,14 +16,14 @@ from ..composer import ( SfgComposerMixIn, ) from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude -from ..ir.source_components import SfgVar, SfgSymbolLike +from ..ir.source_components import SfgSymbolLike from ..ir import ( SfgCallTreeNode, SfgCallTreeLeaf, SfgKernelCallNode, ) -from ..lang import AugExpr +from ..lang import SfgVar, AugExpr class SyclComposerMixIn(SfgComposerMixIn): diff --git a/src/pystencilssfg/ir/__init__.py b/src/pystencilssfg/ir/__init__.py index 43660069ce5e049f60b42193698427f1756d2fae..1ae1749367527472aaf5ba77ddf09a8081ed578b 100644 --- a/src/pystencilssfg/ir/__init__.py +++ b/src/pystencilssfg/ir/__init__.py @@ -19,7 +19,6 @@ from .source_components import ( SfgEmptyLines, SfgKernelNamespace, SfgKernelHandle, - SfgVar, SfgSymbolLike, SfgFunction, SfgVisibility, @@ -51,7 +50,6 @@ __all__ = [ "SfgEmptyLines", "SfgKernelNamespace", "SfgKernelHandle", - "SfgVar", "SfgSymbolLike", "SfgFunction", "SfgVisibility", diff --git a/src/pystencilssfg/ir/call_tree.py b/src/pystencilssfg/ir/call_tree.py index 4db0daf2a37c0b7e11c24e3d79735378aa78096b..4a084649b068487a5e712fb1d7f49f87153bc5ee 100644 --- a/src/pystencilssfg/ir/call_tree.py +++ b/src/pystencilssfg/ir/call_tree.py @@ -4,7 +4,8 @@ from typing import TYPE_CHECKING, Sequence, Iterable, NewType from abc import ABC, abstractmethod from itertools import chain -from .source_components import SfgHeaderInclude, SfgKernelHandle, SfgVar +from .source_components import SfgHeaderInclude, SfgKernelHandle +from ..lang import SfgVar if TYPE_CHECKING: from ..context import SfgContext diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index c037c522324771b455bd5796de45af8bac0164df..6de8e184231a420709999ad094136154a813ed33 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -19,8 +19,8 @@ from pystencils.backend.kernelfunction import ( from ..exceptions import SfgException from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements -from ..ir.source_components import SfgVar, SfgSymbolLike -from ..lang import IFieldExtraction, SrcField, SrcVector +from ..ir.source_components import SfgSymbolLike +from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector if TYPE_CHECKING: from ..context import SfgContext diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index c8a72df8ce16123f627f2ec521e44cd99ca62927..5007851eeb9027ef025c5e46c3d6da4c019cf2c0 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -2,7 +2,7 @@ from __future__ import annotations from abc import ABC from enum import Enum, auto -from typing import TYPE_CHECKING, Sequence, Generator, TypeVar, Generic, Any +from typing import TYPE_CHECKING, Sequence, Generator, TypeVar, Generic from dataclasses import replace from itertools import chain @@ -14,6 +14,7 @@ from pystencils.backend.kernelfunction import ( ) from pystencils.types import PsType, PsCustomType +from ..lang import SfgVar from ..exceptions import SfgException if TYPE_CHECKING: @@ -31,6 +32,7 @@ class SfgEmptyLines: class SfgHeaderInclude: + """Represent ``#include``-directives.""" @staticmethod def parse(incl: str | SfgHeaderInclude, private: bool = False): @@ -203,56 +205,6 @@ class SfgKernelHandle: return self._namespace.get_kernel_function(self) -class SfgVar: - __match_args__ = ("name", "dtype") - - def __init__( - self, - name: str, - dtype: PsType, - required_includes: set[SfgHeaderInclude] | None = None, - ): - self._name = name - self._dtype = dtype - - self._required_includes = ( - required_includes if required_includes is not None else set() - ) - - @property - def name(self) -> str: - return self._name - - @property - def dtype(self) -> PsType: - return self._dtype - - def _args(self) -> tuple[Any, ...]: - return (self._name, self._dtype) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SfgVar): - return False - - return self._args() == other._args() - - def __hash__(self) -> int: - return hash(self._args()) - - @property - def required_includes(self) -> set[SfgHeaderInclude]: - return self._required_includes - - def name_and_type(self) -> str: - return f"{self._name}: {self._dtype}" - - def __str__(self) -> str: - return self._name - - def __repr__(self) -> str: - return f"{self._name}: {self._dtype}" - - SymbolLike_T = TypeVar("SymbolLike_T", bound=KernelParameter) diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index 83f8e5a787b12913e96b45d4c6483611d100c8be..d67ffa0c845b16c867064365cec8b5af5ffe6a2a 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -1,11 +1,12 @@ from .expressions import ( + SfgVar, + AugExpr, VarLike, _VarLike, ExprLike, _ExprLike, asvar, depends, - AugExpr, IFieldExtraction, SrcField, SrcVector, @@ -14,13 +15,14 @@ from .expressions import ( from .types import Ref __all__ = [ + "SfgVar", + "AugExpr", "VarLike", "_VarLike", "ExprLike", "_ExprLike", "asvar", "depends", - "AugExpr", "IFieldExtraction", "SrcField", "SrcVector", diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index aee3375d27126de4128a68245833e407a0fef4f0..1a15fcd84f3c8200bfdd44df20c4c4b7f2ecc00f 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -1,14 +1,89 @@ from __future__ import annotations -from typing import Iterable, TypeAlias +from typing import Iterable, TypeAlias, Any, TYPE_CHECKING from itertools import chain from abc import ABC, abstractmethod from pystencils import TypedSymbol from pystencils.types import PsType -from ..ir.source_components import SfgVar, SfgHeaderInclude from ..exceptions import SfgException +if TYPE_CHECKING: + from ..ir.source_components import SfgHeaderInclude + + +__all__ = [ + "SfgVar", + "AugExpr", + "VarLike", + "ExprLike", + "asvar", + "depends", + "IFieldExtraction", + "SrcField", + "SrcVector", +] + + +class SfgVar: + """C++ Variable. + + Args: + name: Name of the variable. Must be a valid C++ identifer. + dtype: Data type of the variable. + """ + + __match_args__ = ("name", "dtype") + + def __init__( + self, + name: str, + dtype: PsType, + required_includes: set[SfgHeaderInclude] | None = None, + ): + # TODO: Replace `required_includes` by using a property + # Includes attached this way may currently easily be lost during postprocessing, + # since they are not part of `_args` + self._name = name + self._dtype = dtype + + self._required_includes = ( + required_includes if required_includes is not None else set() + ) + + @property + def name(self) -> str: + return self._name + + @property + def dtype(self) -> PsType: + return self._dtype + + def _args(self) -> tuple[Any, ...]: + return (self._name, self._dtype) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SfgVar): + return False + + return self._args() == other._args() + + def __hash__(self) -> int: + return hash(self._args()) + + @property + def required_includes(self) -> set[SfgHeaderInclude]: + return self._required_includes + + def name_and_type(self) -> str: + return f"{self._name}: {self._dtype}" + + def __str__(self) -> str: + return self._name + + def __repr__(self) -> str: + return f"{self._name}: {self._dtype}" + class DependentExpression: __match_args__ = ("expr", "depends") @@ -62,7 +137,7 @@ class VarExpr(DependentExpression): class AugExpr: - """C++ Expression augmented with variable dependencies and a type-dependent interface. + """C++ expression augmented with variable dependencies and a type-dependent interface. `AugExpr` is the primary class for modelling C++ expressions in *pystencils-sfg*. It stores both an expression's code string and the set of variables (`SfgVar`) @@ -71,6 +146,9 @@ class AugExpr: In addition, subclasses of `AugExpr` can mimic C++ APIs by defining factory methods that build expressions for C++ method calls, etc., from a list of argument expressions. + + Args: + dtype: Optional, data type of this expression interface """ __match_args__ = ("expr", "dtype") @@ -175,7 +253,7 @@ ExprLike: TypeAlias = str | AugExpr | SfgVar | TypedSymbol This type combines all objects that *pystencils-sfg* can handle in the place of C++ expressions. These include all valid variable types (`VarLike`), plain strings, and -complex expressions with variable dependency information ('AugExpr`). +complex expressions with variable dependency information (`AugExpr`). The set of variables an expression depends on can be determined using `depends`. """ @@ -257,7 +335,11 @@ class IFieldExtraction(ABC): class SrcField(AugExpr): - """Represents a C++ data structure that can be mapped to a *pystencils* field.""" + """Represents a C++ data structure that can be mapped to a *pystencils* field. + + Args: + dtype: Data type of the field data structure + """ @abstractmethod def get_extraction(self) -> IFieldExtraction: @@ -265,7 +347,11 @@ class SrcField(AugExpr): class SrcVector(AugExpr, ABC): - """Represents a C++ data structure that represents a mathematical vector.""" + """Represents a C++ data structure that represents a mathematical vector. + + Args: + dtype: Data type of the vector data structure + """ @abstractmethod def extract_component(self, coordinate: int) -> AugExpr: