Coverage for src/pystencilssfg/lang/expressions.py: 89%

237 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from __future__ import annotations 

2 

3from typing import Iterable, TypeAlias, Any, cast 

4from itertools import chain 

5 

6import sympy as sp 

7 

8from pystencils import TypedSymbol 

9from pystencils.codegen import Parameter 

10from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type 

11 

12from ..exceptions import SfgException 

13from .headers import HeaderFile 

14from .types import strip_ptr_ref, CppType, CppTypeFactory, cpptype 

15 

16 

class SfgVar:
    """A named, typed C++ variable.

    Equality and hashing are both driven by the tuple returned from `_args`;
    for a plain `SfgVar` that tuple is ``(name, dtype)``. Subclasses may
    override `_args` to change what identifies a variable.

    Args:
        name: Name of the variable. Must be a valid C++ identifier.
        dtype: Data type of the variable; normalized through `create_type`.
    """

    __match_args__ = ("name", "dtype")

    def __init__(self, name: str, dtype: UserTypeSpec):
        self._name = name
        self._dtype = create_type(dtype)

    @property
    def name(self) -> str:
        """The variable's identifier."""
        return self._name

    @property
    def dtype(self) -> PsType:
        """The variable's data type."""
        return self._dtype

    def _args(self) -> tuple[Any, ...]:
        # Identity key shared by __eq__ and __hash__.
        return self._name, self._dtype

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SfgVar) and self._args() == other._args()

    def __hash__(self) -> int:
        return hash(self._args())

    def name_and_type(self) -> str:
        """Render the variable as ``"<name>: <dtype>"``."""
        return f"{self._name}: {self._dtype}"

    def __str__(self) -> str:
        return self._name

    def __repr__(self) -> str:
        return self.name_and_type()

63 

64 

class SfgKernelParamVar(SfgVar):
    """`SfgVar` wrapping a pystencils kernel parameter (`pystencils.codegen.Parameter`).

    The variable's name and data type are taken from the wrapped parameter.
    Equality and hashing are determined by the wrapped parameter alone, so an
    `SfgKernelParamVar` never compares equal to a plain `SfgVar`, even one with
    the same name and data type.

    Args:
        param: The kernel parameter to wrap.
    """

    __match_args__ = ("wrapped",)

    def __init__(self, param: Parameter):
        self._param = param
        super().__init__(param.name, param.dtype)

    @property
    def wrapped(self) -> Parameter:
        """The wrapped kernel parameter."""
        return self._param

    def _args(self) -> tuple[Any, ...]:
        # Identity key for SfgVar.__eq__ / __hash__: the Parameter object itself.
        return (self._param,)

80 

81 

class DependentExpression:
    """Wrapper around a C++ expression code string,
    annotated with a set of variables and a set of header files this expression depends on.

    Dependencies are stored as frozensets and exposed through read-only properties.
    Two instances are equal (and hash equally) if their code strings, variable
    dependencies and header dependencies all coincide.

    Args:
        expr: C++ code string of the expression
        depends: Iterable of variables and/or `AugExpr` from which variable and header dependencies are collected
        includes: Iterable of header files which this expression additionally depends on
    """

    __match_args__ = ("expr", "depends")

    def __init__(
        self,
        expr: str,
        depends: Iterable[SfgVar | AugExpr],
        includes: Iterable[HeaderFile] | None = None,
    ):
        self._expr: str = expr
        deps: set[SfgVar] = set()
        incls: set[HeaderFile] = set(includes) if includes is not None else set()

        for obj in depends:
            if isinstance(obj, AugExpr):
                # An AugExpr contributes both its variables and its headers.
                deps |= obj.depends
                incls |= obj.includes
            else:
                deps.add(obj)

        self._depends = frozenset(deps)
        self._includes = frozenset(incls)

    @property
    def expr(self) -> str:
        """The C++ code string."""
        return self._expr

    @property
    def depends(self) -> frozenset[SfgVar]:
        """Variables this expression depends on."""
        return self._depends

    @property
    def includes(self) -> frozenset[HeaderFile]:
        """Header files this expression depends on."""
        return self._includes

    # NOTE(review): dunder-style names are reserved for Python protocols by
    # convention; the name is kept here for compatibility with existing callers.
    def __hash_contents__(self):
        return (self._expr, self._depends, self._includes)

    def __eq__(self, other: object):
        if not isinstance(other, DependentExpression):
            return False

        return self.__hash_contents__() == other.__hash_contents__()

    def __hash__(self):
        return hash(self.__hash_contents__())

    def __str__(self) -> str:
        return self.expr

    def __add__(self, other: DependentExpression):
        """Concatenate the code strings and merge the dependencies of two expressions.

        Returns ``NotImplemented`` for operands that are not `DependentExpression`,
        so Python can try a reflected operation or raise a proper `TypeError`.
        """
        if not isinstance(other, DependentExpression):
            return NotImplemented

        return DependentExpression(
            self.expr + other.expr,
            self.depends | other.depends,
            self._includes | other._includes,
        )

147 

148 

class VarExpr(DependentExpression):
    """`DependentExpression` consisting of a single variable.

    The code string is the variable's name, the only variable dependency is the
    variable itself, and the header dependencies are derived from its data type.

    Args:
        var: The variable this expression refers to.
    """

    def __init__(self, var: SfgVar):
        self._var = var
        stripped = strip_ptr_ref(var.dtype)
        incls: Iterable[HeaderFile]
        if isinstance(stripped, CppType):
            # C++ class types carry their own list of includes.
            incls = stripped.class_includes
        else:
            # NOTE(review): this reads headers from the *unstripped* dtype and,
            # unlike the module-level `includes()` helper, does not add
            # <cstdint> for integer types — confirm whether this is intentional.
            incls = (HeaderFile.parse(hdr) for hdr in var.dtype.required_headers)
        super().__init__(var.name, (var,), incls)

    @property
    def variable(self) -> SfgVar:
        """The variable wrapped by this expression."""
        return self._var

166 

167 

class AugExpr:
    """C++ expression augmented with variable dependencies and a type-dependent interface.

    `AugExpr` is the primary class for modelling C++ expressions in *pystencils-sfg*.
    It stores both an expression's code string,
    the set of variables (`SfgVar`) the expression depends on,
    as well as any headers that must be included for the expression to be evaluated.
    This dependency information is used by the composer and postprocessing system
    to infer function parameter lists and automatic header inclusions.

    **Construction and Binding**

    Constructing an `AugExpr` is a two-step process comprising *construction* and *binding*.
    An `AugExpr` can be constructed with or without an associated data type.
    After construction, the `AugExpr` object is still *unbound*;
    it does not yet hold any syntax.

    Syntax binding can happen in two ways:

    - Calling `var <AugExpr.var>` on an unbound `AugExpr` turns it into a *variable* with the given name.
      This variable expression takes its set of required header files from the
      `required_headers <pystencils.types.PsType.required_headers>` field of the data type of the `AugExpr`
      (or, for C++ class types, from the class's own include list).
    - Using `bind <AugExpr.bind>`, an unbound `AugExpr` can be bound to an arbitrary string
      of code. The `bind` method mirrors the interface of `str.format` to combine sub-expressions
      and collect their dependencies.
      The `format <AugExpr.format>` static method is a wrapper around `bind` for expressions
      without a type.

    An `AugExpr` can be bound only once.

    **C++ API Mirroring**

    Subclasses of `AugExpr` can mimic C++ APIs by defining factory methods that
    build expressions for C++ method calls, etc., from a list of argument expressions.

    Args:
        dtype: Optional, data type of this expression interface
    """

    # Positional sub-patterns of ``case AugExpr(code, dtype)`` resolve to these
    # attributes. `AugExpr` has no ``expr`` attribute, so listing ``expr`` here
    # made every positional class pattern fail with AttributeError.
    # Note that ``code`` raises SfgException on an unbound AugExpr.
    __match_args__ = ("code", "dtype")

    def __init__(self, dtype: UserTypeSpec | None = None):
        self._dtype = create_type(dtype) if dtype is not None else None
        self._bound: DependentExpression | None = None
        # NOTE(review): not read by this class (`is_variable` inspects `_bound`);
        # kept in case subclasses rely on it.
        self._is_variable = False

    def var(self, name: str):
        """Bind an unbound `AugExpr` instance as a new variable of given name.

        Args:
            name: Name of the new variable.

        Returns:
            This `AugExpr`, now bound to a `VarExpr`.

        Raises:
            SfgException: If this `AugExpr` has no data type, or is already bound.
        """
        v = SfgVar(name, self.get_dtype())
        expr = VarExpr(v)
        return self._bind(expr)

    @staticmethod
    def make(
        code: str,
        depends: Iterable[SfgVar | AugExpr],
        dtype: UserTypeSpec | None = None,
    ):
        """Create a new `AugExpr` bound to a verbatim code string.

        Args:
            code: C++ code string; used as-is (no formatting is applied).
            depends: Variables and/or `AugExpr` objects the code depends on.
            dtype: Optional data type of the new expression.
        """
        return AugExpr(dtype)._bind(DependentExpression(code, depends))

    @staticmethod
    def format(fmt: str, *deps, **kwdeps) -> AugExpr:
        """Create a new, untyped `AugExpr` by combining existing expressions.

        Equivalent to ``AugExpr().bind(fmt, *deps, **kwdeps)``.
        """
        return AugExpr().bind(fmt, *deps, **kwdeps)

    def bind(
        self,
        fmt: str | AugExpr,
        *deps,
        require_headers: Iterable[str | HeaderFile] = (),
        **kwdeps,
    ):
        """Bind an unbound `AugExpr` instance to an expression.

        If ``fmt`` is another (bound) `AugExpr`, this expression takes over its
        syntax; no further format arguments are permitted, but ``require_headers``
        are added to the inherited header set.
        Otherwise, ``fmt`` is a format string filled in with ``deps`` and ``kwdeps``
        via `str.format`, and the variable and header dependencies of all
        expression-like arguments are collected.

        Args:
            fmt: Format string, or a bound `AugExpr` whose syntax is adopted.
            deps: Positional arguments to the format string.
            require_headers: Additional headers required by the bound expression.
            kwdeps: Keyword arguments to the format string.

        Returns:
            This `AugExpr`.

        Raises:
            ValueError: If ``fmt`` is an `AugExpr` and format arguments are given,
                if ``fmt`` is an unbound `AugExpr`, or if a format argument is a
                non-constant SymPy expression.
            SfgException: If this `AugExpr` is already bound.
        """
        if isinstance(fmt, AugExpr):
            if bool(deps) or bool(kwdeps):
                raise ValueError(
                    "Binding to another AugExpr does not permit additional arguments"
                )
            if fmt._bound is None:
                raise ValueError("Cannot rebind to unbound AugExpr.")

            bound = fmt._bound
            extra_incls = set(HeaderFile.parse(h) for h in require_headers)
            if extra_incls:
                # Honor `require_headers` instead of silently dropping them:
                # keep code and variable dependencies, extend the header set.
                bound = DependentExpression(
                    bound.expr, bound.depends, bound.includes | extra_incls
                )
            self._bind(bound)
        else:
            dependencies: set[SfgVar] = set()
            incls: set[HeaderFile] = set(HeaderFile.parse(h) for h in require_headers)

            from pystencils.sympyextensions import is_constant

            for expr in chain(deps, kwdeps.values()):
                if isinstance(expr, _ExprLike):
                    dependencies |= depends(expr)
                    incls |= includes(expr)
                elif isinstance(expr, sp.Expr) and not is_constant(expr):
                    # Non-constant SymPy expressions contain untyped symbols,
                    # whose dependencies cannot be determined.
                    raise ValueError(
                        f"Cannot parse SymPy expression as C++ expression: {expr}\n"
                        " * pystencils-sfg is currently unable to parse non-constant SymPy expressions "
                        "since they contain symbols without type information."
                    )

            code = fmt.format(*deps, **kwdeps)
            self._bind(DependentExpression(code, dependencies, incls))
        return self

    @property
    def code(self) -> str:
        """The bound C++ code string; raises `SfgException` if unbound."""
        if self._bound is None:
            raise SfgException("No syntax bound to this AugExpr.")
        return str(self._bound)

    @property
    def depends(self) -> frozenset[SfgVar]:
        """Variables the bound expression depends on; raises `SfgException` if unbound."""
        if self._bound is None:
            raise SfgException("No syntax bound to this AugExpr.")

        return self._bound.depends

    @property
    def includes(self) -> frozenset[HeaderFile]:
        """Headers the bound expression depends on; raises `SfgException` if unbound."""
        if self._bound is None:
            raise SfgException("No syntax bound to this AugExpr.")

        return self._bound.includes

    @property
    def dtype(self) -> PsType | None:
        """Data type of this expression, or ``None`` if unknown."""
        return self._dtype

    def get_dtype(self) -> PsType:
        """Return the data type; raises `SfgException` if it is unknown."""
        if self._dtype is None:
            raise SfgException("This AugExpr has no known data type.")

        return self._dtype

    @property
    def is_variable(self) -> bool:
        """Whether this expression is bound to a single variable."""
        return isinstance(self._bound, VarExpr)

    def as_variable(self) -> SfgVar:
        """Return the bound variable; raises `SfgException` if not a variable."""
        if not isinstance(self._bound, VarExpr):
            raise SfgException("This expression is not a variable")
        return self._bound.variable

    def __str__(self) -> str:
        if self._bound is None:
            return "/* [ERROR] unbound AugExpr */"
        else:
            return str(self._bound)

    def __repr__(self) -> str:
        return str(self)

    def _bind(self, expr: DependentExpression):
        # Single point of binding; enforces the bind-once rule.
        if self._bound is not None:
            raise SfgException("Attempting to bind an already-bound AugExpr.")

        self._bound = expr
        return self

    def is_bound(self) -> bool:
        """Whether syntax has been bound to this expression."""
        return self._bound is not None

327 

328 

class CppClass(AugExpr):
    """Convenience base class for C++ API mirroring.

    Example:
        To reflect a C++ class (template) in pystencils-sfg, you may create a subclass
        of `CppClass` like this:

        >>> class MyClassTemplate(CppClass):
        ...     template = lang.cpptype("mynamespace::MyClassTemplate< {T} >", "MyHeader.hpp")


        Then use `AugExpr` initialization and binding to create variables or expressions with
        this class:

        >>> var = MyClassTemplate(T="float").var("myObj")
        >>> var
        myObj

        >>> str(var.dtype).strip()
        'mynamespace::MyClassTemplate< float >'
    """

    # Type factory set by subclasses; instantiated with the constructor arguments.
    template: CppTypeFactory

    def __init__(self, *args, const: bool = False, ref: bool = False, **kwargs):
        super().__init__(self.template(*args, **kwargs, const=const, ref=ref))

    def ctor_bind(self, *args):
        """Bind this expression to a brace-initialization ``T{arg0, arg1, ...}`` of its own type.

        The type's includes are added as required headers.
        """
        dtype = cast(CppType, self.get_dtype())
        placeholders = ", ".join("{}" for _ in args)
        # The doubled braces survive `str.format` inside `bind` as literal braces.
        fstr = "".join((dtype.c_string(), "{{", placeholders, "}}"))
        return self.bind(fstr, *args, require_headers=dtype.includes)

361 

362 

def cppclass(
    template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
):
    """Convenience class decorator for `CppClass`.

    It derives a new class from the decorated class and `CppClass`
    and sets its ``template`` attribute via `cpptype`.
    The new class keeps the decorated class's name, qualified name,
    module and docstring.

    >>> @cppclass("MyClass", "MyClass.hpp")
    ... class MyClass:
    ...     pass

    Args:
        template_str: Type template string, forwarded to `cpptype`.
        include: Header file(s) required by the type, forwarded to `cpptype`.

    Returns:
        A class decorator.
    """

    def wrapper(cls):
        # Without these entries, `type()` would attribute the new class to this
        # module and drop the decorated class's qualified name and docstring.
        namespace = {
            "__module__": cls.__module__,
            "__qualname__": cls.__qualname__,
            "__doc__": cls.__doc__,
        }
        new_cls = type(cls.__name__, (cls, CppClass), namespace)
        new_cls.template = cpptype(template_str, include)
        return new_cls

    return wrapper

382 

383 

# Runtime counterpart of `VarLike` for `isinstance` checks; keep both in sync.
_VarLike = (AugExpr, SfgVar, TypedSymbol)
VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol
"""Things that may act as a variable.

Variable-like objects are entities from pystencils and pystencils-sfg that define
a variable name and data type.
An `AugExpr` only acts as a variable if it is bound to one (e.g. via `AugExpr.var`).
Any `VarLike` object can be transformed into a canonical representation (i.e. `SfgVar`)
using `asvar`.
"""


# Runtime counterpart of `ExprLike` for `isinstance` checks (used by `AugExpr.bind`);
# keep both in sync.
_ExprLike = (str, AugExpr, SfgVar, TypedSymbol)
ExprLike: TypeAlias = str | AugExpr | SfgVar | TypedSymbol
"""Things that may act as a C++ expression.

This type combines all objects that *pystencils-sfg* can handle in the place of C++
expressions. These include all valid variable types (`VarLike`), plain strings, and
complex expressions with variable dependency information (`AugExpr`).

The set of variables an expression depends on can be determined using `depends`,
and the set of header files it requires using `includes`.
"""

405 

406 

407def asvar(var: VarLike) -> SfgVar: 

408 """Cast a variable-like object to its canonical representation, 

409 

410 Args: 

411 var: Variable-like object 

412 

413 Returns: 

414 SfgVar: Variable cast as `SfgVar`. 

415 

416 Raises: 

417 ValueError: If given a non-variable `AugExpr`, 

418 a `TypedSymbol <pystencils.TypedSymbol>` 

419 with a `DynamicType <pystencils.sympyextensions.typed_sympy.DynamicType>`, 

420 or any non-variable-like object. 

421 """ 

422 match var: 

423 case SfgVar(): 

424 return var 

425 case AugExpr(): 

426 return var.as_variable() 

427 case TypedSymbol(): 

428 from pystencils import DynamicType 

429 

430 if isinstance(var.dtype, DynamicType): 

431 raise ValueError( 

432 f"Unable to cast dynamically typed symbol {var} to a variable.\n" 

433 f"{var} has dynamic type {var.dtype}, which cannot be resolved to a type outside of a kernel." 

434 ) 

435 

436 return SfgVar(var.name, var.dtype) 

437 case _: 

438 raise ValueError(f"Invalid variable: {var}") 

439 

440 

def depends(expr: ExprLike) -> set[SfgVar]:
    """Collect the variables an expression-like object refers to.

    Plain strings and ``None`` carry no dependency information and yield an
    empty set. An `SfgVar` depends only on itself, a `TypedSymbol` is cast via
    `asvar`, and for an `AugExpr` a copy of its bound dependency set is returned.

    Args:
        expr: Expression-like object to examine

    Returns:
        set[SfgVar]: A fresh, mutable set of variables the expression depends on

    Raises:
        ValueError: If the argument was not a valid expression
    """
    if expr is None or isinstance(expr, str):
        return set()
    if isinstance(expr, SfgVar):
        return {expr}
    if isinstance(expr, TypedSymbol):
        return {asvar(expr)}
    if isinstance(expr, AugExpr):
        return set(expr.depends)
    raise ValueError(f"Invalid expression: {expr}")

465 

466 

467def includes(obj: ExprLike | PsType) -> set[HeaderFile]: 

468 """Determine the set of header files an expression depends on. 

469 

470 Args: 

471 expr: Expression-like object to examine 

472 

473 Returns: 

474 set[HeaderFile]: Set of headers the expression depends on 

475 

476 Raises: 

477 ValueError: If the argument was not a valid variable or expression 

478 """ 

479 

480 if isinstance(obj, PsType): 

481 obj = strip_ptr_ref(obj) 

482 

483 match obj: 

484 case CppType(): 

485 return set(obj.includes) 

486 

487 case PsType(): 

488 headers = set(HeaderFile.parse(h) for h in obj.required_headers) 

489 if isinstance(obj, PsIntegerType): 

490 headers.add(HeaderFile.parse("<cstdint>")) 

491 return headers 

492 

493 case SfgVar(_, dtype): 

494 return includes(dtype) 

495 

496 case TypedSymbol(): 

497 return includes(asvar(obj)) 

498 

499 case str(): 

500 return set() 

501 

502 case AugExpr(): 

503 return set(obj.includes) 

504 

505 case _: 

506 raise ValueError(f"Invalid expression: {obj}")