diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 35da6c5a1603756957b5a2b5752744a4f5f2f2aa..b791f1cd225751236e18d7d9ed354e10f61bf356 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -5,7 +5,7 @@ import numpy as np import sympy as sp from functools import reduce -from pystencils import Field, TypedSymbol +from pystencils import Field from pystencils.backend import KernelFunction from pystencils.types import ( create_type, @@ -46,7 +46,18 @@ from ..ir.source_components import ( SfgClassKeyword, SfgVar, ) -from ..lang import IFieldExtraction, SrcVector, AugExpr, SrcField +from ..lang import ( + VarLike, + ExprLike, + _VarLike, + _ExprLike, + asvar, + depends, + IFieldExtraction, + SrcVector, + AugExpr, + SrcField, +) from ..exceptions import SfgException @@ -65,20 +76,6 @@ class SfgNodeBuilder(ABC): pass -_ExprLike = (str, AugExpr, SfgVar, TypedSymbol) -ExprLike: TypeAlias = str | AugExpr | SfgVar | TypedSymbol -"""Things that may act as a C++ expression. - -Expressions need not necesserily have a known data type. -""" - -_VarLike = (TypedSymbol, SfgVar, AugExpr) -VarLike: TypeAlias = TypedSymbol | SfgVar | AugExpr -"""Things that may act as a variable. - -Variables must always define their name *and* data type. -""" - _SequencerArg = (tuple, ExprLike, SfgCallTreeNode, SfgNodeBuilder) SequencerArg: TypeAlias = tuple | ExprLike | SfgCallTreeNode | SfgNodeBuilder """Valid arguments to `make_sequence` and any sequencer that uses it.""" @@ -158,6 +155,7 @@ class SfgBasicComposer(SfgIComposer): def generate(self, generator: CustomGenerator): """Invoke a custom code generator with the underlying context.""" from .composer import SfgComposer + generator.generate(SfgComposer(self)) @property @@ -285,12 +283,12 @@ class SfgBasicComposer(SfgIComposer): tpb_str = str(threads_per_block) stream_str = str(stream) if stream is not None else None - depends = _depends(num_blocks) | _depends(threads_per_block) + deps = depends(num_blocks) | depends(threads_per_block) if stream is not None: - depends |= _depends(stream) + deps |= depends(stream) return SfgCudaKernelInvocation( - kernel_handle, num_blocks_str, tpb_str, stream_str, depends + kernel_handle, num_blocks_str, tpb_str, stream_str, deps ) def seq(self, *args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence: @@ -362,7 +360,7 @@ class SfgBasicComposer(SfgIComposer): SomeClass obj { arg1, arg2, arg3 }; """ - return SfgInplaceInitBuilder(_asvar(lhs)) + return SfgInplaceInitBuilder(asvar(lhs)) def expr(self, fmt: str, *deps, **kwdeps) -> AugExpr: """Create an expression while keeping track of variables it depends on. @@ -430,11 +428,11 @@ class SfgBasicComposer(SfgIComposer): return SfgDeferredFieldMapping(field, index_provider) def set_param(self, param: VarLike | sp.Symbol, expr: ExprLike): - depends = _depends(expr) + deps = depends(expr) var: SfgVar | sp.Symbol = ( - _asvar(param) if isinstance(param, _VarLike) else param + asvar(param) if isinstance(param, _VarLike) else param ) - return SfgDeferredParamSetter(var, depends, str(expr)) + return SfgDeferredParamSetter(var, deps, str(expr)) def map_param( self, @@ -447,10 +445,10 @@ class SfgBasicComposer(SfgIComposer): if isinstance(depends, _VarLike): depends = [depends] lhs_var: SfgVar | sp.Symbol = ( - _asvar(param) if isinstance(param, _VarLike) else param + asvar(param) if isinstance(param, _VarLike) else param ) return SfgDeferredParamMapping( - lhs_var, set(_asvar(v) for v in depends), mapping + lhs_var, set(asvar(v) for v in depends), mapping ) def map_vector(self, lhs_components: Sequence[VarLike | sp.Symbol], rhs: SrcVector): @@ -461,14 +459,13 @@ class SfgBasicComposer(SfgIComposer): rhs: A `SrcVector` object representing a vector data structure. """ components: list[SfgVar | sp.Symbol] = [ - (_asvar(c) if isinstance(c, _VarLike) else c) for c in lhs_components + (asvar(c) if isinstance(c, _VarLike) else c) for c in lhs_components ] return SfgDeferredVectorMapping(components, rhs) def make_statements(arg: ExprLike) -> SfgStatements: - depends = _depends(arg) - return SfgStatements(str(arg), (), depends) + return SfgStatements(str(arg), (), depends(arg)) def make_sequence(*args: SequencerArg) -> SfgSequence: @@ -538,7 +535,7 @@ def make_sequence(*args: SequencerArg) -> SfgSequence: class SfgInplaceInitBuilder(SfgNodeBuilder): def __init__(self, lhs: SfgVar) -> None: self._lhs: SfgVar = lhs - self._depends: set[SfgVar] = set() + self.depends: set[SfgVar] = set() self._rhs: str | None = None def __call__( @@ -549,7 +546,7 @@ class SfgInplaceInitBuilder(SfgNodeBuilder): raise SfgException("Assignment builder used multiple times.") self._rhs = ", ".join(str(expr) for expr in rhs) - self._depends = reduce(set.union, (_depends(obj) for obj in rhs), set()) + self.depends = reduce(set.union, (depends(obj) for obj in rhs), set()) return self def resolve(self) -> SfgCallTreeNode: @@ -557,7 +554,7 @@ class SfgInplaceInitBuilder(SfgNodeBuilder): return SfgStatements( f"{self._lhs.dtype} {self._lhs.name} {{ {self._rhs} }};", [self._lhs], - self._depends, + self.depends, ) @@ -660,37 +657,3 @@ def struct_from_numpy_dtype( cls.default.append_member(SfgConstructor(constr_params, constr_inits)) return cls - - -def _asvar(var: VarLike) -> SfgVar: - match var: - case SfgVar(): - return var - case AugExpr(): - return var.as_variable() - case TypedSymbol(): - from pystencils import DynamicType - - if isinstance(var.dtype, DynamicType): - raise SfgException( - f"Unable to cast dynamically typed symbol {var} to a variable.\n" - f"{var} has dynamic type {var.dtype}, which cannot be resolved to a type outside of a kernel." - ) - - return SfgVar(var.name, var.dtype) - case _: - raise ValueError(f"Invalid variable: {var}") - - -def _depends(expr: ExprLike) -> set[SfgVar]: - match expr: - case None | str(): - return set() - case SfgVar(): - return {expr} - case TypedSymbol(): - return {_asvar(expr)} - case AugExpr(): - return expr.depends - case _: - raise ValueError(f"Invalid expression: {expr}") diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index ed081dc0450e0fc04e432cb07b5d8de45188c84e..bd906782d6e074955c60221d9a8f4d9b15bae772 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -3,6 +3,13 @@ from typing import Sequence from pystencils.types import PsCustomType, UserTypeSpec +from ..lang import ( + _VarLike, + VarLike, + ExprLike, + asvar, +) + from ..ir.source_components import ( SfgClass, SfgClassMember, @@ -19,11 +26,7 @@ from ..exceptions import SfgException from .mixin import SfgComposerMixIn from .basic_composer import ( make_sequence, - _VarLike, - VarLike, - ExprLike, SequencerArg, - _asvar, ) @@ -69,7 +72,7 @@ class SfgClassComposer(SfgComposerMixIn): """ def __init__(self, *params: VarLike): - self._params = tuple(_asvar(p) for p in params) + self._params = tuple(asvar(p) for p in params) self._initializers: list[str] = [] self._body: str | None = None @@ -78,7 +81,7 @@ class SfgClassComposer(SfgComposerMixIn): def init_sequencer(*args: ExprLike): expr = ", ".join(str(arg) for arg in args) - initializer = f"{_asvar(var)}{{ {expr} }}" + initializer = f"{asvar(var)}{{ {expr} }}" self._initializers.append(initializer) return self @@ -219,7 +222,7 @@ class SfgClassComposer(SfgComposerMixIn): ) -> SfgClassMember: match arg: case _ if isinstance(arg, _VarLike): - var = _asvar(arg) + var = asvar(arg) return SfgMemberVariable(var.name, var.dtype) case str(): return SfgInClassDefinition(arg) diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index 99661af8ec4ec6306b6aad3a219b88a8303712cd..83f8e5a787b12913e96b45d4c6483611d100c8be 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -1,5 +1,10 @@ from .expressions import ( - DependentExpression, + VarLike, + _VarLike, + ExprLike, + _ExprLike, + asvar, + depends, AugExpr, IFieldExtraction, SrcField, @@ -9,7 +14,12 @@ from .expressions import ( from .types import Ref __all__ = [ - "DependentExpression", + "VarLike", + "_VarLike", + "ExprLike", + "_ExprLike", + "asvar", + "depends", "AugExpr", "IFieldExtraction", "SrcField", diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 948bbd68601a26e91afc61af3fe771cc85ef2b00..aee3375d27126de4128a68245833e407a0fef4f0 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import Iterable +from typing import Iterable, TypeAlias from itertools import chain from abc import ABC, abstractmethod +from pystencils import TypedSymbol from pystencils.types import PsType from ..ir.source_components import SfgVar, SfgHeaderInclude @@ -61,6 +62,17 @@ class VarExpr(DependentExpression): class AugExpr: + """C++ Expression augmented with variable dependencies and a type-dependent interface. + + `AugExpr` is the primary class for modelling C++ expressions in *pystencils-sfg*. + It stores both an expression's code string and the set of variables (`SfgVar`) + the expression depends on. This dependency information is used by the postprocessing + system to infer function parameter lists. + + In addition, subclasses of `AugExpr` can mimic C++ APIs by defining factory methods that + build expressions for C++ method calls, etc., from a list of argument expressions. + """ + __match_args__ = ("expr", "dtype") def __init__(self, dtype: PsType | None = None): @@ -146,6 +158,87 @@ class AugExpr: return self._bound is not None +_VarLike = (AugExpr, SfgVar, TypedSymbol) +VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol +"""Things that may act as a variable. + +Variable-like objects are entities from pystencils and pystencils-sfg that define +a variable name and data type. +Any `VarLike` object can be transformed into a canonical representation (i.e. `SfgVar`) +using `asvar`. +""" + + +_ExprLike = (str, AugExpr, SfgVar, TypedSymbol) +ExprLike: TypeAlias = str | AugExpr | SfgVar | TypedSymbol +"""Things that may act as a C++ expression. + +This type combines all objects that *pystencils-sfg* can handle in the place of C++ +expressions. These include all valid variable types (`VarLike`), plain strings, and +complex expressions with variable dependency information ('AugExpr`). + +The set of variables an expression depends on can be determined using `depends`. +""" + + +def asvar(var: VarLike) -> SfgVar: + """Cast a variable-like object to its canonical representation, + + Args: + var: Variable-like object + + Returns: + SfgVar: Variable cast as `SfgVar`. + + Raises: + SfgException: If given a non-variable `AugExpr`, or a `TypedSymbol` with a `DynamicType` + ValueError: If given any non-variable-like object. + """ + match var: + case SfgVar(): + return var + case AugExpr(): + return var.as_variable() + case TypedSymbol(): + from pystencils import DynamicType + + if isinstance(var.dtype, DynamicType): + raise SfgException( + f"Unable to cast dynamically typed symbol {var} to a variable.\n" + f"{var} has dynamic type {var.dtype}, which cannot be resolved to a type outside of a kernel." + ) + + return SfgVar(var.name, var.dtype) + case _: + raise ValueError(f"Invalid variable: {var}") + + +def depends(expr: ExprLike) -> set[SfgVar]: + """Determine the set of variables an expression depends on. + + Args: + expr: Expression-like object to examine + + Returns: + set[SfgVar]: Set of variables the expression depends on + + Raises: + ValueError: If the argument was not a valid expression + """ + + match expr: + case None | str(): + return set() + case SfgVar(): + return {expr} + case TypedSymbol(): + return {asvar(expr)} + case AugExpr(): + return expr.depends + case _: + raise ValueError(f"Invalid expression: {expr}") + + class IFieldExtraction(ABC): """Interface for objects defining how to extract low-level field parameters from high-level data structures."""