Coverage for src/pystencilssfg/ir/entities.py: 87%

270 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from __future__ import annotations 

2 

3from dataclasses import dataclass 

4from abc import ABC 

5from enum import Enum, auto 

6from typing import ( 

7 TYPE_CHECKING, 

8 Sequence, 

9 Generator, 

10) 

11from itertools import chain 

12 

13from pystencils import Field 

14from pystencils.codegen import Kernel 

15from pystencils.types import PsType, PsCustomType 

16 

17from ..lang import SfgVar, SfgKernelParamVar, void, ExprLike 

18from ..exceptions import SfgException 

19 

20if TYPE_CHECKING: 

21 from . import SfgCallTreeNode 

22 

23 

24# ========================================================================================================= 

25# 

26# SEMANTICAL ENTITIES 

27# 

28# These classes model *code entities*, which represent *semantic components* of the generated files. 

29# 

30# ========================================================================================================= 

31 

32 

class SfgCodeEntity:
    """Common base of all named code entities.

    An entity is identified by its local ``name`` together with the namespace
    enclosing it; both combine into its fully qualified name.

    Args:
        name: Local (unqualified) name of the entity
        parent_namespace: Namespace directly enclosing the entity
    """

    def __init__(self, name: str, parent_namespace: SfgNamespace) -> None:
        self._namespace: SfgNamespace = parent_namespace
        self._name = name

    @property
    def name(self) -> str:
        """Local (unqualified) name of this entity"""
        return self._name

    @property
    def fqname(self) -> str:
        """Fully qualified name of this entity.

        Entities living directly in the global namespace carry no prefix;
        all others are prefixed by their parent's qualified name and ``::``.
        """
        enclosing = self._namespace
        if isinstance(enclosing, SfgGlobalNamespace):
            return self._name
        return f"{enclosing.fqname}::{self._name}"

    @property
    def parent_namespace(self) -> SfgNamespace | None:
        """Namespace directly enclosing this entity"""
        return self._namespace

60 

61 

62class SfgNamespace(SfgCodeEntity): 

63 """A C++ namespace. 

64 

65 Each namespace has a name and a parent; its fully qualified name is given as 

66 ``<parent.name>::<name>``. 

67 

68 Args: 

69 name: Local name of this namespace 

70 parent: Parent namespace enclosing this namespace 

71 """ 

72 

73 def __init__(self, name: str, parent_namespace: SfgNamespace) -> None: 

74 super().__init__(name, parent_namespace) 

75 

76 self._entities: dict[str, SfgCodeEntity] = dict() 

77 

78 def get_entity(self, qual_name: str) -> SfgCodeEntity | None: 

79 """Find an entity with the given qualified name within this namespace. 

80 

81 If ``qual_name`` contains any qualifying delimiters ``::``, 

82 each component but the last is interpreted as a namespace. 

83 """ 

84 tokens = qual_name.split("::", 1) 

85 match tokens: 

86 case [entity_name]: 

87 return self._entities.get(entity_name, None) 

88 case [nspace, remaining_qualname]: 

89 sub_nspace = self._entities.get(nspace, None) 

90 if sub_nspace is not None: 

91 if not isinstance(sub_nspace, SfgNamespace): 

92 raise KeyError( 

93 f"Unable to find entity {qual_name} in namespace {self._name}: " 

94 f"Entity {nspace} is not a namespace." 

95 ) 

96 return sub_nspace.get_entity(remaining_qualname) 

97 else: 

98 return None 

99 case _: 

100 assert False, "unreachable code" 

101 

102 def add_entity(self, entity: SfgCodeEntity): 

103 if entity.name in self._entities: 

104 raise ValueError( 

105 f"Another entity with the name {entity.fqname} already exists" 

106 ) 

107 self._entities[entity.name] = entity 

108 

109 def get_child_namespace(self, qual_name: str): 

110 if not qual_name: 

111 raise ValueError("Anonymous namespaces are not supported") 

112 

113 # Find the namespace by qualified lookup ... 

114 namespace = self.get_entity(qual_name) 

115 if namespace is not None: 

116 if not type(namespace) is SfgNamespace: 

117 raise ValueError(f"Entity {qual_name} exists, but is not a namespace") 

118 else: 

119 # ... or create it 

120 tokens = qual_name.split("::") 

121 namespace = self 

122 for tok in tokens: 

123 namespace = SfgNamespace(tok, namespace) 

124 

125 return namespace 

126 

127 

class SfgGlobalNamespace(SfgNamespace):
    """The C++ global namespace.

    It has an empty name, is registered as its own parent, and its fully
    qualified name is the empty string.
    """

    def __init__(self) -> None:
        SfgNamespace.__init__(self, "", self)

    @property
    def fqname(self) -> str:
        """Always the empty string for the global namespace"""
        return ""

137 

138 

class SfgKernelHandle(SfgCodeEntity):
    """Handle to a pystencils kernel.

    Wraps each kernel parameter in an `SfgKernelParamVar`. Field parameters
    contribute their fields to `fields`; all others are collected in
    `scalar_parameters`. The constructor does not register the handle with
    ``namespace``.

    Args:
        name: Name of the kernel handle
        namespace: Kernel namespace the handle belongs to
        kernel: The pystencils kernel
        inline: Inline flag, exposed unchanged via `inline`
    """

    __match_args__ = ("kernel", "parameters")

    def __init__(
        self,
        name: str,
        namespace: SfgKernelNamespace,
        kernel: Kernel,
        inline: bool = False,
    ):
        super().__init__(name, namespace)

        self._kernel = kernel

        wrapped_params = [SfgKernelParamVar(p) for p in kernel.parameters]
        self._parameters = wrapped_params

        self._inline: bool = inline

        scalars: set[SfgVar] = set()
        accessed_fields: set[Field] = set()
        for kparam in wrapped_params:
            if not kparam.wrapped.is_field_parameter:
                scalars.add(kparam)
                continue
            accessed_fields.update(kparam.wrapped.fields)

        self._scalar_params = scalars
        self._fields = accessed_fields

    @property
    def parameters(self) -> Sequence[SfgKernelParamVar]:
        """All parameters of this kernel, in the kernel's order"""
        return self._parameters

    @property
    def scalar_parameters(self) -> set[SfgVar]:
        """Parameters that are not field parameters"""
        return self._scalar_params

    @property
    def fields(self) -> set[Field]:
        """Fields referenced by the kernel's field parameters"""
        return self._fields

    @property
    def kernel(self) -> Kernel:
        """Underlying pystencils kernel object"""
        return self._kernel

    @property
    def inline(self) -> bool:
        """Inline flag given at construction"""
        return self._inline

190 

191 

class SfgKernelNamespace(SfgNamespace):
    """A namespace grouping together a number of kernels.

    Kernels are kept in a registry of their own; `add_kernel` does not touch
    the generic entity table inherited from `SfgNamespace`.
    """

    def __init__(self, name: str, parent: SfgNamespace):
        super().__init__(name, parent)
        #  Registered kernel handles, keyed by kernel name
        self._kernels: dict[str, SfgKernelHandle] = {}

    @property
    def name(self):
        """Local name of this kernel namespace"""
        return self._name

    @property
    def kernels(self) -> tuple[SfgKernelHandle, ...]:
        """All registered kernels, in registration order"""
        return tuple(self._kernels.values())

    def find_kernel(self, name: str) -> SfgKernelHandle | None:
        """Return the kernel registered as ``name``, or ``None``"""
        return self._kernels.get(name)

    def add_kernel(self, kernel: SfgKernelHandle):
        """Register ``kernel``.

        Raises:
            ValueError: If a kernel of the same name is already registered.
        """
        kname = kernel.name
        if kname not in self._kernels:
            self._kernels[kname] = kernel
            return
        raise ValueError(
            f"Duplicate kernels: A kernel called {kname} already exists "
            f"in namespace {self.fqname}"
        )

217 

218 

219@dataclass(frozen=True, match_args=False) 

220class CommonFunctionProperties: 

221 tree: SfgCallTreeNode 

222 parameters: tuple[SfgVar, ...] 

223 return_type: PsType 

224 inline: bool 

225 constexpr: bool 

226 attributes: Sequence[str] 

227 

228 @staticmethod 

229 def collect_params(tree: SfgCallTreeNode, required_params: Sequence[SfgVar] | None): 

230 from .postprocessing import CallTreePostProcessing 

231 

232 param_collector = CallTreePostProcessing() 

233 params_set = param_collector(tree).function_params 

234 

235 if required_params is not None: 

236 if not (params_set <= set(required_params)): 

237 extras = params_set - set(required_params) 

238 raise SfgException( 

239 "Extraenous function parameters: " 

240 f"Found free variables {extras} that were not listed in manually specified function parameters." 

241 ) 

242 parameters = tuple(required_params) 

243 else: 

244 parameters = tuple(sorted(params_set, key=lambda p: p.name)) 

245 

246 return parameters 

247 

248 

class SfgFunction(SfgCodeEntity, CommonFunctionProperties):
    """A free function.

    The parameter list is derived from ``tree`` by
    `CommonFunctionProperties.collect_params`, honoring ``required_params``
    if it is given.
    """

    __match_args__ = ("name", "tree", "parameters", "return_type")  # type: ignore

    def __init__(
        self,
        name: str,
        namespace: SfgNamespace,
        tree: SfgCallTreeNode,
        return_type: PsType = void,
        inline: bool = False,
        constexpr: bool = False,
        attributes: Sequence[str] = (),
        required_params: Sequence[SfgVar] | None = None,
    ):
        SfgCodeEntity.__init__(self, name, namespace)

        CommonFunctionProperties.__init__(
            self,
            tree=tree,
            parameters=self.collect_params(tree, required_params),
            return_type=return_type,
            inline=inline,
            constexpr=constexpr,
            attributes=attributes,
        )

278 

279 

class SfgVisibility(Enum):
    """C++ access specifiers for class members.

    ``str()`` yields the C++ keyword. `DEFAULT` renders as the empty string,
    i.e. no explicit access specifier.
    """

    DEFAULT = auto()
    PRIVATE = auto()
    PROTECTED = auto()
    PUBLIC = auto()

    def __str__(self) -> str:
        if self is SfgVisibility.DEFAULT:
            return ""
        #  All remaining members are spelled like their lower-cased member name
        return self.name.lower()

297 return "public" 

298 

299 

class SfgClassKeyword(Enum):
    """C++ class-key keywords; ``str()`` yields the keyword itself."""

    STRUCT = auto()
    CLASS = auto()

    def __str__(self) -> str:
        #  Each keyword is exactly the lower-cased member name
        return self.name.lower()

312 

313 

class SfgClassMember(ABC):
    """Common base of all class member entities.

    A member stores a reference to its owning class and a visibility. The
    visibility is ``None`` until it is assigned; reading `visibility` before
    then raises `SfgException`.
    """

    def __init__(self, cls: SfgClass) -> None:
        self._cls: SfgClass = cls
        #  Unset until the member is given a visibility
        self._visibility: SfgVisibility | None = None

    @property
    def owning_class(self) -> SfgClass:
        """The class this member belongs to"""
        owner = self._cls
        if owner is None:
            raise SfgException(f"{self} is not bound to a class.")
        return owner

    @property
    def visibility(self) -> SfgVisibility:
        """Visibility of this member within its owning class"""
        vis = self._visibility
        if vis is not None:
            return vis
        raise SfgException(
            f"{self} is not bound to a class and therefore has no visibility."
        )

334 

335 

class SfgMemberVariable(SfgVar, SfgClassMember):
    """A data member (field) of a class.

    Args:
        name: Name of the member variable
        dtype: Data type of the member variable
        cls: Owning class
        default_init: Default-initializer arguments, or ``None`` if there are none
    """

    def __init__(
        self,
        name: str,
        dtype: PsType,
        cls: SfgClass,
        default_init: tuple[ExprLike, ...] | None = None,
    ):
        #  Both bases are initialized explicitly, in order: variable part, then member part
        SfgVar.__init__(self, name, dtype)
        SfgClassMember.__init__(self, cls)

        self._default_init = default_init

    @property
    def default_init(self) -> tuple[ExprLike, ...] | None:
        """Default-initializer arguments of this member, or ``None``"""
        return self._default_init

353 

354 

class SfgMethod(SfgClassMember, CommonFunctionProperties):
    """A method of a class.

    Besides the common function properties it stores the C++ method qualifiers
    ``static``, ``const``, ``virtual`` and ``override``. The parameter list is
    derived from ``tree`` by `CommonFunctionProperties.collect_params`.
    """

    __match_args__ = ("name", "tree", "parameters", "return_type")  # type: ignore

    def __init__(
        self,
        name: str,
        cls: SfgClass,
        tree: SfgCallTreeNode,
        return_type: PsType = void,
        inline: bool = False,
        const: bool = False,
        static: bool = False,
        constexpr: bool = False,
        virtual: bool = False,
        override: bool = False,
        attributes: Sequence[str] = (),
        required_params: Sequence[SfgVar] | None = None,
    ):
        SfgClassMember.__init__(self, cls)

        self._name = name

        #  Method qualifiers
        self._static = static
        self._const = const
        self._virtual = virtual
        self._override = override

        CommonFunctionProperties.__init__(
            self,
            tree=tree,
            parameters=CommonFunctionProperties.collect_params(tree, required_params),
            return_type=return_type,
            inline=inline,
            constexpr=constexpr,
            attributes=attributes,
        )

    @property
    def name(self) -> str:
        """Name of this method"""
        return self._name

    @property
    def static(self) -> bool:
        """Whether the method is marked ``static``"""
        return self._static

    @property
    def const(self) -> bool:
        """Whether the method is marked ``const``"""
        return self._const

    @property
    def virtual(self) -> bool:
        """Whether the method is marked ``virtual``"""
        return self._virtual

    @property
    def override(self) -> bool:
        """Whether the method is marked ``override``"""
        return self._override

414 

415 

class SfgConstructor(SfgClassMember):
    """A constructor of a class.

    Args:
        cls: Owning class
        parameters: Constructor parameters, stored as a tuple
        initializers: Initializer entries as ``(member, args)`` pairs, where
            ``member`` is a variable or a name string; stored as a tuple
        body: Constructor body, given as a string
    """

    __match_args__ = ("owning_class", "parameters", "initializers", "body")

    def __init__(
        self,
        cls: SfgClass,
        parameters: Sequence[SfgVar] = (),
        initializers: Sequence[tuple[SfgVar | str, tuple[ExprLike, ...]]] = (),
        body: str = "",
    ):
        SfgClassMember.__init__(self, cls)

        self._parameters = tuple(parameters)
        self._initializers = tuple(initializers)
        self._body = body

    @property
    def parameters(self) -> tuple[SfgVar, ...]:
        """Constructor parameters"""
        return self._parameters

    @property
    def initializers(self) -> tuple[tuple[SfgVar | str, tuple[ExprLike, ...]], ...]:
        """Initializer entries as ``(member, args)`` pairs"""
        return self._initializers

    @property
    def body(self) -> str:
        """Constructor body string"""
        return self._body

444 

445 

class SfgClass(SfgCodeEntity):
    """A C++ class.

    Members (constructors, methods, member variables) are registered through
    `add_member`, which also assigns their visibility. The member query
    methods accept an optional visibility to filter by.

    Args:
        name: Name of the class
        namespace: Enclosing namespace
        class_keyword: Class-key used for the definition (``class`` or ``struct``)
        bases: Base class specifiers as a sequence of strings; a plain string
            is rejected
    """

    __match_args__ = ("class_keyword", "name")

    def __init__(
        self,
        name: str,
        namespace: SfgNamespace,
        class_keyword: SfgClassKeyword = SfgClassKeyword.CLASS,
        bases: Sequence[str] = (),
    ):
        #  A str is itself a Sequence[str]; reject it to avoid splitting a
        #  single base name into characters.
        if isinstance(bases, str):
            raise ValueError("Base classes must be given as a sequence.")

        super().__init__(name, namespace)

        self._class_keyword = class_keyword
        self._bases_classes = tuple(bases)

        self._constructors: list[SfgConstructor] = []
        self._methods: list[SfgMethod] = []
        #  Member variables keyed by name; names must be unique
        self._member_vars: dict[str, SfgMemberVariable] = dict()

    @property
    def src_type(self) -> PsType:
        """Type object naming this class"""
        # TODO: Use CppTypeFactory instead
        return PsCustomType(self._name)

    @property
    def base_classes(self) -> tuple[str, ...]:
        """Base class specifiers"""
        return self._bases_classes

    @property
    def class_keyword(self) -> SfgClassKeyword:
        """Class-key of this class"""
        return self._class_keyword

    def members(
        self, visibility: SfgVisibility | None = None
    ) -> Generator[SfgClassMember, None, None]:
        """Iterate constructors, then methods, then member variables.

        If ``visibility`` is given, only members with that visibility are yielded.
        """
        if visibility is None:
            yield from chain(
                self._constructors, self._methods, self._member_vars.values()
            )
        else:
            yield from filter(lambda m: m.visibility == visibility, self.members())

    def member_variables(
        self, visibility: SfgVisibility | None = None
    ) -> Generator[SfgMemberVariable, None, None]:
        """Iterate member variables, optionally filtered by ``visibility``."""
        if visibility is not None:
            yield from filter(
                lambda m: m.visibility == visibility, self._member_vars.values()
            )
        else:
            yield from self._member_vars.values()

    def constructors(
        self, visibility: SfgVisibility | None = None
    ) -> Generator[SfgConstructor, None, None]:
        """Iterate constructors, optionally filtered by ``visibility``."""
        if visibility is not None:
            yield from filter(lambda m: m.visibility == visibility, self._constructors)
        else:
            yield from self._constructors

    def methods(
        self, visibility: SfgVisibility | None = None
    ) -> Generator[SfgMethod, None, None]:
        """Iterate methods, optionally filtered by ``visibility``."""
        if visibility is not None:
            yield from filter(lambda m: m.visibility == visibility, self._methods)
        else:
            yield from self._methods

    def add_member(self, member: SfgClassMember, vis: SfgVisibility):
        """Register ``member`` with this class under visibility ``vis``.

        Raises:
            SfgException: If ``member`` is not a constructor, member variable
                or method, or if a member variable of the same name already
                exists.
        """
        if isinstance(member, SfgConstructor):
            self._constructors.append(member)
        elif isinstance(member, SfgMemberVariable):
            self._add_member_variable(member)
        elif isinstance(member, SfgMethod):
            self._methods.append(member)
        else:
            raise SfgException(f"{member} is not a valid class member.")

        #  Bind the visibility only after successful registration. Without it,
        #  `member.visibility` raises and every visibility-filtered query fails.
        member._visibility = vis

    def _add_member_variable(self, variable: SfgMemberVariable):
        """Insert ``variable`` into the name-keyed member variable table."""
        if variable.name in self._member_vars:
            raise SfgException(
                f"Duplicate field name {variable.name} in class {self._name}"
            )

        self._member_vars[variable.name] = variable