Coverage for src/pystencilssfg/lang/expressions.py: 89%
237 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from __future__ import annotations
3from typing import Iterable, TypeAlias, Any, cast
4from itertools import chain
6import sympy as sp
8from pystencils import TypedSymbol
9from pystencils.codegen import Parameter
10from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type
12from ..exceptions import SfgException
13from .headers import HeaderFile
14from .types import strip_ptr_ref, CppType, CppTypeFactory, cpptype
class SfgVar:
    """A named and typed C++ variable.

    Args:
        name: Name of the variable; must be a valid C++ identifier.
        dtype: Data type of the variable; converted to a `PsType` via ``create_type``.
    """

    __match_args__ = ("name", "dtype")

    def __init__(
        self,
        name: str,
        dtype: UserTypeSpec,
    ):
        self._name = name
        self._dtype = create_type(dtype)

    @property
    def name(self) -> str:
        """The variable's identifier."""
        return self._name

    @property
    def dtype(self) -> PsType:
        """The variable's data type."""
        return self._dtype

    def _args(self) -> tuple[Any, ...]:
        # Identity key shared by __eq__ and __hash__; subclasses override it.
        return (self._name, self._dtype)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SfgVar) and self._args() == other._args()

    def __hash__(self) -> int:
        return hash(self._args())

    def name_and_type(self) -> str:
        """Render the variable as ``name: dtype``."""
        name, dtype = self._name, self._dtype
        return f"{name}: {dtype}"

    def __str__(self) -> str:
        return self._name

    def __repr__(self) -> str:
        return self.name_and_type()
class SfgKernelParamVar(SfgVar):
    """Expose a pystencils kernel `Parameter` as an `SfgVar`.

    The variable takes its name and data type from the wrapped parameter.
    Equality and hashing are keyed on the wrapped `Parameter` object alone,
    so two instances of this class compare equal exactly when their wrapped
    parameters do.

    Args:
        param: The kernel parameter to wrap.
    """

    __match_args__ = ("wrapped",)

    def __init__(self, param: Parameter):
        self._param = param
        super().__init__(param.name, param.dtype)

    @property
    def wrapped(self) -> Parameter:
        """The wrapped kernel parameter."""
        return self._param

    def _args(self):
        # Identity is the wrapped parameter rather than (name, dtype).
        return (self._param,)
class DependentExpression:
    """A C++ expression code string annotated with its dependencies.

    Besides the code string, each instance records the set of variables and
    the set of header files the expression depends on.

    Args:
        expr: C++ code string of the expression
        depends: Iterable of variables and/or `AugExpr` objects; plain variables
            are recorded directly, while the variable and header dependencies of
            every `AugExpr` are merged into this expression's sets
        includes: Iterable of header files this expression additionally depends on
    """

    __match_args__ = ("expr", "depends")

    def __init__(
        self,
        expr: str,
        depends: Iterable[SfgVar | AugExpr],
        includes: Iterable[HeaderFile] | None = None,
    ):
        self._expr: str = expr

        variables: set[SfgVar] = set()
        headers: set[HeaderFile] = set() if includes is None else set(includes)

        for dep in depends:
            if not isinstance(dep, AugExpr):
                variables.add(dep)
                continue
            # Sub-expressions contribute both their variables and their headers
            variables.update(dep.depends)
            headers.update(dep.includes)

        self._depends = frozenset(variables)
        self._includes = frozenset(headers)

    @property
    def expr(self) -> str:
        """The C++ code string."""
        return self._expr

    @property
    def depends(self) -> frozenset[SfgVar]:
        """Variables this expression depends on."""
        return self._depends

    @property
    def includes(self) -> frozenset[HeaderFile]:
        """Header files this expression depends on."""
        return self._includes

    def __hash_contents__(self):
        # Key shared by __eq__ and __hash__
        return (self._expr, self._depends, self._includes)

    def __eq__(self, other: object):
        if isinstance(other, DependentExpression):
            return self.__hash_contents__() == other.__hash_contents__()
        return False

    def __hash__(self):
        return hash(self.__hash_contents__())

    def __str__(self) -> str:
        return self.expr

    def __add__(self, other: DependentExpression):
        """Concatenate the code strings and merge the dependency sets."""
        joined_code = self.expr + other.expr
        joined_vars = self.depends | other.depends
        joined_headers = self._includes | other._includes
        return DependentExpression(joined_code, joined_vars, joined_headers)
class VarExpr(DependentExpression):
    """Expression consisting of a single variable's name.

    The expression depends on exactly that variable. Its header dependencies
    are derived from the variable's type: if the type, after ``strip_ptr_ref``,
    is a `CppType`, its ``class_includes`` are used; otherwise the
    ``required_headers`` of the variable's data type are parsed.

    Args:
        var: The variable this expression refers to.
    """

    def __init__(self, var: SfgVar):
        self._var = var
        stripped = strip_ptr_ref(var.dtype)
        headers: Iterable[HeaderFile]
        if isinstance(stripped, CppType):
            headers = stripped.class_includes
        else:
            # NOTE(review): this reads the headers of the *unstripped* dtype,
            # unlike the CppType branch above; confirm this is intended.
            headers = map(HeaderFile.parse, var.dtype.required_headers)
        super().__init__(var.name, (var,), headers)

    @property
    def variable(self) -> SfgVar:
        """The variable this expression refers to."""
        return self._var
class AugExpr:
    """C++ expression augmented with variable dependencies and a type-dependent interface.

    `AugExpr` is the primary class for modelling C++ expressions in *pystencils-sfg*.
    It stores both an expression's code string,
    the set of variables (`SfgVar`) the expression depends on,
    as well as any headers that must be included for the expression to be evaluated.
    This dependency information is used by the composer and postprocessing system
    to infer function parameter lists and automatic header inclusions.

    **Construction and Binding**

    Constructing an `AugExpr` is a two-step process comprising *construction* and *binding*.
    An `AugExpr` can be constructed with or without an associated data type.
    After construction, the `AugExpr` object is still *unbound*;
    it does not yet hold any syntax.

    Syntax binding can happen in two ways:

    - Calling `var <AugExpr.var>` on an unbound `AugExpr` turns it into a *variable* with the given name.
      This variable expression takes its set of required header files from the data type of the `AugExpr`
      (its `required_headers <pystencils.types.PsType.required_headers>`, or the class includes
      of a C++ class type).
    - Using `bind <AugExpr.bind>`, an unbound `AugExpr` can be bound to an arbitrary string
      of code. The `bind` method mirrors the interface of `str.format` to combine sub-expressions
      and collect their dependencies.
      The `format <AugExpr.format>` static method is a wrapper around `bind` for expressions
      without a type.

    An `AugExpr` can be bound only once.

    **C++ API Mirroring**

    Subclasses of `AugExpr` can mimic C++ APIs by defining factory methods that
    build expressions for C++ method calls, etc., from a list of argument expressions.

    Args:
        dtype: Optional, data type of this expression interface
    """

    # NOTE(review): `AugExpr` defines no `expr` attribute, so a positional class
    # pattern such as `case AugExpr(e, t)` can never match; confirm intent.
    __match_args__ = ("expr", "dtype")

    def __init__(self, dtype: UserTypeSpec | None = None):
        self._dtype = create_type(dtype) if dtype is not None else None
        self._bound: DependentExpression | None = None
        # NOTE(review): not read anywhere in this file; `is_variable` inspects `_bound`.
        self._is_variable = False

    def var(self, name: str) -> AugExpr:
        """Bind an unbound `AugExpr` instance as a new variable of given name.

        Raises:
            SfgException: If this expression has no data type or is already bound.
        """
        v = SfgVar(name, self.get_dtype())
        expr = VarExpr(v)
        return self._bind(expr)

    @staticmethod
    def make(
        code: str,
        depends: Iterable[SfgVar | AugExpr],
        dtype: UserTypeSpec | None = None,
    ) -> AugExpr:
        """Create a new `AugExpr` bound to ``code`` with the given dependencies.

        Args:
            code: C++ code string of the expression
            depends: Variables and/or `AugExpr` objects the code depends on
            dtype: Optional data type of the new expression
        """
        return AugExpr(dtype)._bind(DependentExpression(code, depends))

    @staticmethod
    def format(fmt: str, *deps, **kwdeps) -> AugExpr:
        """Create a new, untyped `AugExpr` by combining existing expressions.

        Equivalent to ``AugExpr().bind(fmt, *deps, **kwdeps)``.
        """
        return AugExpr().bind(fmt, *deps, **kwdeps)

    def bind(
        self,
        fmt: str | AugExpr,
        *deps,
        require_headers: Iterable[str | HeaderFile] = (),
        **kwdeps,
    ) -> AugExpr:
        """Bind an unbound `AugExpr` instance to an expression.

        Args:
            fmt: Either a format string, which is filled in via `str.format` using
                ``deps`` and ``kwdeps``; or another, already bound `AugExpr`, whose
                syntax and dependencies are adopted.
            deps: Positional format arguments. Expression-like arguments
                contribute their variable and header dependencies.
            require_headers: Additional headers the expression requires.
                They are honored in both binding modes.
            kwdeps: Keyword format arguments, treated like ``deps``.

        Returns:
            This object, now bound.

        Raises:
            ValueError: If ``fmt`` is an `AugExpr` and format arguments are given,
                if ``fmt`` is an unbound `AugExpr`, or if a format argument is a
                non-constant SymPy expression.
            SfgException: If this object is already bound.
        """
        if isinstance(fmt, AugExpr):
            if deps or kwdeps:
                raise ValueError(
                    "Binding to another AugExpr does not permit additional arguments"
                )
            if fmt._bound is None:
                raise ValueError("Cannot rebind to unbound AugExpr.")

            extra_incls = set(HeaderFile.parse(h) for h in require_headers)
            if extra_incls:
                # Do not drop explicitly requested headers: merge them into a copy
                # of the other expression's syntax and dependencies.
                other = fmt._bound
                self._bind(
                    DependentExpression(
                        other.expr, other.depends, other.includes | extra_incls
                    )
                )
            else:
                # Share the other expression's syntax object unchanged.
                self._bind(fmt._bound)
        else:
            dependencies: set[SfgVar] = set()
            incls: set[HeaderFile] = set(HeaderFile.parse(h) for h in require_headers)

            from pystencils.sympyextensions import is_constant

            for expr in chain(deps, kwdeps.values()):
                if isinstance(expr, _ExprLike):
                    dependencies |= depends(expr)
                    incls |= includes(expr)
                elif isinstance(expr, sp.Expr) and not is_constant(expr):
                    # Symbols in a SymPy expression carry no type information,
                    # so they cannot be turned into typed variables.
                    raise ValueError(
                        f"Cannot parse SymPy expression as C++ expression: {expr}\n"
                        " * pystencils-sfg is currently unable to parse non-constant SymPy expressions "
                        "since they contain symbols without type information."
                    )

            code = fmt.format(*deps, **kwdeps)
            self._bind(DependentExpression(code, dependencies, incls))
        return self

    @property
    def code(self) -> str:
        """The bound C++ code string.

        Raises:
            SfgException: If this expression is unbound.
        """
        if self._bound is None:
            raise SfgException("No syntax bound to this AugExpr.")
        return str(self._bound)

    @property
    def depends(self) -> frozenset[SfgVar]:
        """Variables this expression depends on.

        Raises:
            SfgException: If this expression is unbound.
        """
        if self._bound is None:
            raise SfgException("No syntax bound to this AugExpr.")

        return self._bound.depends

    @property
    def includes(self) -> frozenset[HeaderFile]:
        """Header files this expression depends on.

        Raises:
            SfgException: If this expression is unbound.
        """
        if self._bound is None:
            raise SfgException("No syntax bound to this AugExpr.")

        return self._bound.includes

    @property
    def dtype(self) -> PsType | None:
        """Data type of this expression, or ``None`` if untyped."""
        return self._dtype

    def get_dtype(self) -> PsType:
        """Return the data type of this expression.

        Raises:
            SfgException: If this expression has no data type.
        """
        if self._dtype is None:
            raise SfgException("This AugExpr has no known data type.")

        return self._dtype

    @property
    def is_variable(self) -> bool:
        """Whether this expression was bound as a variable (see `var`)."""
        return isinstance(self._bound, VarExpr)

    def as_variable(self) -> SfgVar:
        """Return the variable this expression was bound to.

        Raises:
            SfgException: If this expression is not a variable.
        """
        if not isinstance(self._bound, VarExpr):
            raise SfgException("This expression is not a variable")
        return self._bound.variable

    def __str__(self) -> str:
        if self._bound is None:
            return "/* [ERROR] unbound AugExpr */"
        else:
            return str(self._bound)

    def __repr__(self) -> str:
        return str(self)

    def _bind(self, expr: DependentExpression) -> AugExpr:
        # Central binding point; enforces that an AugExpr is bound at most once.
        if self._bound is not None:
            raise SfgException("Attempting to bind an already-bound AugExpr.")

        self._bound = expr
        return self

    def is_bound(self) -> bool:
        """Whether syntax has been bound to this expression."""
        return self._bound is not None
class CppClass(AugExpr):
    """Convenience base class for C++ API mirroring.

    Example:
        To reflect a C++ class (template) in pystencils-sfg, you may create a subclass
        of `CppClass` like this:

        >>> class MyClassTemplate(CppClass):
        ...     template = lang.cpptype("mynamespace::MyClassTemplate< {T} >", "MyHeader.hpp")

        Then use `AugExpr` initialization and binding to create variables or expressions with
        this class:

        >>> var = MyClassTemplate(T="float").var("myObj")
        >>> var
        myObj

        >>> str(var.dtype).strip()
        'mynamespace::MyClassTemplate< float >'
    """

    template: CppTypeFactory

    def __init__(self, *args, const: bool = False, ref: bool = False, **kwargs):
        # Instantiate the class template with the given arguments and qualifiers
        super().__init__(self.template(*args, **kwargs, const=const, ref=ref))

    def ctor_bind(self, *args):
        """Bind this object to the brace-initialization ``Type{arg0, arg1, ...}``.

        The arguments are inserted via `bind`, and the type's ``includes`` are
        added as required headers.
        """
        cpp_type = cast(CppType, self.get_dtype())
        placeholders = ", ".join("{}" for _ in args)
        # "{{" / "}}" survive `str.format` in `bind` as literal braces
        template_code = "".join((cpp_type.c_string(), "{{", placeholders, "}}"))
        return self.bind(template_code, *args, require_headers=cpp_type.includes)
def cppclass(
    template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
):
    """Convenience class decorator for `CppClass`.

    The decorator creates a new class of the same name that derives from both
    the decorated class and `CppClass`. It sets the new class's ``template``
    attribute via `cpptype`. The new class keeps the decorated class's
    ``__module__``, ``__qualname__`` and ``__doc__``.

    >>> @cppclass("MyClass", "MyClass.hpp")
    ... class MyClass:
    ...     pass

    Args:
        template_str: Type template string passed to `cpptype`.
        include: Header file(s) passed to `cpptype`.

    Returns:
        A class decorator.
    """

    def wrapper(cls):
        # Without these entries, `type()` would record this module as the new
        # class's __module__ and would not carry over the decorated class's
        # docstring and qualified name.
        namespace = {
            "__module__": cls.__module__,
            "__qualname__": cls.__qualname__,
            "__doc__": cls.__doc__,
        }
        new_cls = type(cls.__name__, (cls, CppClass), namespace)
        new_cls.template = cpptype(template_str, include)
        return new_cls

    return wrapper
_VarLike = (AugExpr, SfgVar, TypedSymbol)
VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol
"""Objects that can stand in for a variable.

`VarLike` covers the entities from pystencils and pystencils-sfg that carry
both a variable name and a data type.
A `VarLike` object that actually denotes a variable can be converted to the
canonical variable representation, `SfgVar`, using `asvar`.
"""


_ExprLike = (str, *_VarLike)
ExprLike: TypeAlias = str | AugExpr | SfgVar | TypedSymbol
"""Objects that can stand in for a C++ expression.

`ExprLike` comprises everything *pystencils-sfg* accepts in place of a C++
expression: all variable-like types (`VarLike`), plain strings, and complex
expressions carrying variable dependency information (`AugExpr`).

Use `depends` to determine the set of variables an expression depends on.
"""
407def asvar(var: VarLike) -> SfgVar:
408 """Cast a variable-like object to its canonical representation,
410 Args:
411 var: Variable-like object
413 Returns:
414 SfgVar: Variable cast as `SfgVar`.
416 Raises:
417 ValueError: If given a non-variable `AugExpr`,
418 a `TypedSymbol <pystencils.TypedSymbol>`
419 with a `DynamicType <pystencils.sympyextensions.typed_sympy.DynamicType>`,
420 or any non-variable-like object.
421 """
422 match var:
423 case SfgVar():
424 return var
425 case AugExpr():
426 return var.as_variable()
427 case TypedSymbol():
428 from pystencils import DynamicType
430 if isinstance(var.dtype, DynamicType):
431 raise ValueError(
432 f"Unable to cast dynamically typed symbol {var} to a variable.\n"
433 f"{var} has dynamic type {var.dtype}, which cannot be resolved to a type outside of a kernel."
434 )
436 return SfgVar(var.name, var.dtype)
437 case _:
438 raise ValueError(f"Invalid variable: {var}")
def depends(expr: ExprLike) -> set[SfgVar]:
    """Collect the variables an expression depends on.

    Args:
        expr: Expression-like object to examine; ``None`` is accepted as well.

    Returns:
        set[SfgVar]: A new set of the variables the expression depends on;
            empty for plain strings and ``None``.

    Raises:
        ValueError: If the argument was not a valid expression
    """
    if expr is None or isinstance(expr, str):
        return set()
    if isinstance(expr, SfgVar):
        return {expr}
    if isinstance(expr, TypedSymbol):
        return {asvar(expr)}
    if isinstance(expr, AugExpr):
        # Copy the frozenset so callers may mutate the result
        return set(expr.depends)
    raise ValueError(f"Invalid expression: {expr}")
467def includes(obj: ExprLike | PsType) -> set[HeaderFile]:
468 """Determine the set of header files an expression depends on.
470 Args:
471 expr: Expression-like object to examine
473 Returns:
474 set[HeaderFile]: Set of headers the expression depends on
476 Raises:
477 ValueError: If the argument was not a valid variable or expression
478 """
480 if isinstance(obj, PsType):
481 obj = strip_ptr_ref(obj)
483 match obj:
484 case CppType():
485 return set(obj.includes)
487 case PsType():
488 headers = set(HeaderFile.parse(h) for h in obj.required_headers)
489 if isinstance(obj, PsIntegerType):
490 headers.add(HeaderFile.parse("<cstdint>"))
491 return headers
493 case SfgVar(_, dtype):
494 return includes(dtype)
496 case TypedSymbol():
497 return includes(asvar(obj))
499 case str():
500 return set()
502 case AugExpr():
503 return set(obj.includes)
505 case _:
506 raise ValueError(f"Invalid expression: {obj}")