Coverage for src/pystencilssfg/composer/basic_composer.py: 87%

270 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from __future__ import annotations 

2 

3from typing import Sequence, TypeAlias 

4from abc import ABC, abstractmethod 

5import sympy as sp 

6from functools import reduce 

7from warnings import warn 

8 

9from pystencils import ( 

10 Field, 

11 CreateKernelConfig, 

12 create_kernel, 

13 Assignment, 

14 AssignmentCollection, 

15) 

16from pystencils.codegen import Kernel, Lambda 

17from pystencils.types import create_type, UserTypeSpec, PsType 

18 

19from ..context import SfgContext, SfgCursor 

20from .custom import CustomGenerator 

21from ..ir import ( 

22 SfgCallTreeNode, 

23 SfgKernelCallNode, 

24 SfgStatements, 

25 SfgFunctionParams, 

26 SfgRequireIncludes, 

27 SfgSequence, 

28 SfgBlock, 

29 SfgBranch, 

30 SfgSwitch, 

31) 

32from ..ir.postprocessing import ( 

33 SfgDeferredParamSetter, 

34 SfgDeferredFieldMapping, 

35 SfgDeferredVectorMapping, 

36) 

37from ..ir import ( 

38 SfgFunction, 

39 SfgKernelNamespace, 

40 SfgKernelHandle, 

41 SfgEntityDecl, 

42 SfgEntityDef, 

43 SfgNamespaceBlock, 

44) 

45from ..lang import ( 

46 VarLike, 

47 ExprLike, 

48 _VarLike, 

49 _ExprLike, 

50 asvar, 

51 depends, 

52 HeaderFile, 

53 includes, 

54 SfgVar, 

55 SfgKernelParamVar, 

56 AugExpr, 

57 SupportsFieldExtraction, 

58 SupportsVectorExtraction, 

59 void, 

60) 

61from ..exceptions import SfgException 

62 

63 

class SfgIComposer(ABC):
    """Abstract root of the composer hierarchy.

    Stores the code generation context together with that context's cursor,
    so that derived composers can reach both directly.
    """

    def __init__(self, ctx: SfgContext):
        self._ctx, self._cursor = ctx, ctx.cursor

    @property
    def context(self) -> SfgContext:
        """The `SfgContext` this composer operates on."""
        return self._ctx

72 

73 

class SfgNodeBuilder(ABC):
    """Abstract base for the builder objects handed out by the composer.

    A builder gathers the parts of a syntactic construct and is turned
    into a call tree node by calling `resolve`.
    """

    @abstractmethod
    def resolve(self) -> SfgCallTreeNode:
        """Produce the call tree node described by this builder."""

80 

81 

SequencerArg: TypeAlias = tuple | ExprLike | SfgCallTreeNode | SfgNodeBuilder
"""Valid arguments to `make_sequence` and any sequencer that uses it."""

# Tuple form of `SequencerArg`, shaped like the argument of an `isinstance` check.
_SequencerArg = (tuple, ExprLike, SfgCallTreeNode, SfgNodeBuilder)

85 

86 

class KernelsAdder:
    """Handle on a kernel namespace that permits registering kernels."""

    def __init__(self, cursor: SfgCursor, knamespace: SfgKernelNamespace):
        self._cursor = cursor
        self._kernel_namespace = knamespace
        # If set, kernel definitions and their required headers go to the header
        # file instead of the implementation file.
        self._inline: bool = False
        # Namespace block receiving this adder's kernel definitions;
        # created on first use by `_get_loc`.
        self._loc: SfgNamespaceBlock | None = None

    def inline(self) -> KernelsAdder:
        """Generate kernel definitions ``inline`` in the header file."""
        self._inline = True
        return self

    def add(self, kernel: Kernel, name: str | None = None) -> SfgKernelHandle:
        """Add an existing pystencils kernel to this namespace.

        Args:
            kernel: The kernel to register.
            name: Optional new name. If given, the kernel's function name is changed to it;
                otherwise, the kernel's current name is used.

        Returns:
            The handle of the newly registered kernel.

        Raises:
            ValueError: If a kernel of the same name already exists in this namespace.
        """
        kernel_name = kernel.name if name is None else name
        self._check_unique_name(kernel_name)

        if name is not None:
            kernel.name = kernel_name

        khandle = SfgKernelHandle(
            kernel_name, self._kernel_namespace, kernel, inline=self._inline
        )
        self._kernel_namespace.add_kernel(khandle)

        loc = self._get_loc()
        loc.elements.append(SfgEntityDef(khandle))

        # Headers required by the kernel go to the file that receives its definition
        for header in kernel.required_headers:
            hfile = HeaderFile.parse(header)
            if self._inline:
                self._cursor.context.header_file.includes.append(hfile)
            else:
                impl_file = self._cursor.context.impl_file
                assert impl_file is not None
                impl_file.includes.append(hfile)

        return khandle

    def create(
        self,
        assignments: Assignment | Sequence[Assignment] | AssignmentCollection,
        name: str | None = None,
        config: CreateKernelConfig | None = None,
    ) -> SfgKernelHandle:
        """Creates a new pystencils kernel from a list of assignments and a configuration.
        This is a wrapper around `create_kernel <pystencils.codegen.create_kernel>`
        with a subsequent call to `add`.

        Args:
            assignments: The assignments making up the kernel.
            name: Name of the kernel. If omitted, the name of the kernel
                produced by `create_kernel` is used.
            config: Kernel configuration; a default `CreateKernelConfig` is used if omitted.
                The given configuration object is not modified.

        Returns:
            The handle of the newly registered kernel.

        Raises:
            ValueError: If a kernel of the same name already exists in this namespace.
        """
        if config is None:
            config = CreateKernelConfig()

        if name is not None:
            # Check before generating the kernel, to fail fast
            self._check_unique_name(name)

        kernel = create_kernel(assignments, config=config)
        # Naming is delegated to `add` rather than written into `config.function_name`:
        # the caller's config object may be reused for further kernels, which would
        # otherwise silently inherit this name.
        return self.add(kernel, name)

    def _check_unique_name(self, kernel_name: str) -> None:
        """Raise a `ValueError` if ``kernel_name`` is already taken in this kernel namespace."""
        if self._kernel_namespace.find_kernel(kernel_name) is not None:
            raise ValueError(
                f"Duplicate kernels: A kernel called {kernel_name} already exists "
                f"in namespace {self._kernel_namespace.fqname}"
            )

    def _get_loc(self) -> SfgNamespaceBlock:
        """Return the namespace block that receives kernel definitions.

        On first use, the block is created and written to the header file
        (if inline) or to the implementation file.
        """
        if self._loc is None:
            kns_block = SfgNamespaceBlock(self._kernel_namespace)

            if self._inline:
                self._cursor.write_header(kns_block)
            else:
                self._cursor.write_impl(kns_block)

            self._loc = kns_block
        return self._loc

173 

174 

class SfgBasicComposer(SfgIComposer):
    """Composer for basic source components, and base class for all composer mix-ins."""

    def __init__(self, sfg: SfgContext | SfgIComposer):
        ctx: SfgContext = sfg if isinstance(sfg, SfgContext) else sfg.context
        super().__init__(ctx)

    def prelude(self, content: str, end: str = "\n"):
        """Append a string to the prelude comment, to be printed at the top of each generated file.

        The string should not contain C/C++ comment delimiters, since these will be added automatically
        during code generation.

        :Example:
            >>> sfg.prelude("This file was generated using pystencils-sfg; do not modify it directly!")

            will appear in the generated files as

            .. code-block:: C++

                /*
                 * This file was generated using pystencils-sfg; do not modify it directly!
                 */

        """
        for f in self._ctx.files:
            if f.prelude is None:
                f.prelude = content + end
            else:
                f.prelude += content + end

    def code(self, *code: str, impl: bool = False):
        """Add arbitrary lines of code to the generated header (or implementation) file.

        :Example:

            >>> sfg.code(
            ...     "#define PI 3.14 // more than enough for engineers",
            ...     "using namespace std;"
            ... )

            will appear as

            .. code-block:: C++

                #define PI 3.14 // more than enough for engineers
                using namespace std;

        Args:
            code: Sequence of code strings to be written to the output file
            impl: If `True`, write the code to the implementation file; otherwise, to the header file.
        """
        for c in code:
            if impl:
                self._cursor.write_impl(c)
            else:
                self._cursor.write_header(c)

    def define(self, *definitions: str):
        """Deprecated: write the given code strings to the header file. Use `code` instead."""
        warn(
            "The `define` method of `SfgBasicComposer` is deprecated and will be removed in a future version. "
            "Use `sfg.code()` instead.",
            FutureWarning,
            stacklevel=2,
        )

        self.code(*definitions)

    def namespace(self, namespace: str):
        """Enter a new namespace block.

        Calling `namespace` as a regular function will open a new namespace as a child of the
        currently active namespace; this new namespace will then become active instead.
        Using `namespace` as a context manager will instead activate the given namespace
        only for the length of the ``with`` block.

        Args:
            namespace: Qualified name of the namespace

        :Example:

        The following calls will set the current namespace to ``outer::inner``
        for the remaining code generation run:

        .. code-block::

            sfg.namespace("outer")
            sfg.namespace("inner")

        Subsequent calls to `namespace` can only create further nested namespaces.

        To step back out of a namespace, `namespace` can also be used as a context manager:

        .. code-block::

            with sfg.namespace("detail"):
                ...

        This way, code generated inside the ``with`` region is placed in the ``detail`` namespace,
        and code after this block will again live in the enclosing namespace.

        """
        return self._cursor.enter_namespace(namespace)

    def generate(self, generator: CustomGenerator):
        """Invoke a custom code generator with the underlying context."""
        from .composer import SfgComposer

        generator.generate(SfgComposer(self))

    @property
    def kernels(self) -> KernelsAdder:
        """The default kernel namespace.

        Add kernels like::

            sfg.kernels.add(ast, "kernel_name")
            sfg.kernels.create(assignments, "kernel_name", config)
        """
        return self.kernel_namespace("kernels")

    def kernel_namespace(self, name: str) -> KernelsAdder:
        """Return a view on a kernel namespace in order to add kernels to it.

        The namespace is created in the current namespace if it does not exist yet.
        If no implementation file is being generated, kernels are defined inline.

        Raises:
            ValueError: If an entity of that name exists but is not a kernel namespace.
        """
        kns = self._cursor.get_entity(name)
        if kns is None:
            kns = SfgKernelNamespace(name, self._cursor.current_namespace)
            self._cursor.add_entity(kns)
        elif not isinstance(kns, SfgKernelNamespace):
            raise ValueError(
                f"The existing entity {kns.fqname} is not a kernel namespace"
            )

        kadder = KernelsAdder(self._cursor, kns)
        if self._ctx.impl_file is None:
            kadder.inline()
        return kadder

    def include(self, header: str | HeaderFile, private: bool = False):
        """Include a header file.

        Args:
            header: Path to the header file. Enclose in ``<>`` for a system header.
            private: If ``True``, in header-implementation code generation, the header file is
                only included in the implementation file.

        Raises:
            ValueError: If ``private`` is set but no implementation file is being generated.

        :Example:

            >>> sfg.include("<vector>")
            >>> sfg.include("custom.h")

            will be printed as

            .. code-block:: C++

                #include <vector>
                #include "custom.h"
        """
        header_file = HeaderFile.parse(header)

        if private:
            if self._ctx.impl_file is None:
                raise ValueError(
                    "Cannot emit a private include since no implementation file is being generated"
                )
            self._ctx.impl_file.includes.append(header_file)
        else:
            self._ctx.header_file.includes.append(header_file)

    def kernel_function(self, name: str, kernel: Kernel | SfgKernelHandle):
        """Create a function comprising just a single kernel call.

        Args:
            name: Name of the generated function. If ``kernel`` is a pystencils `Kernel`,
                it is also registered under this name in the default kernel namespace.
            kernel: Either a pystencils kernel, or a handle to an already registered kernel.
        """
        if isinstance(kernel, Kernel):
            khandle = self.kernels.add(kernel, name)
        else:
            khandle = kernel

        self.function(name)(self.call(khandle))

    def function(
        self,
        name: str,
        return_type: UserTypeSpec | None = None,
    ) -> SfgFunctionSequencer:
        """Add a function.

        The syntax of this function adder uses a chain of two calls to mimic C++ syntax:

        .. code-block:: Python

            sfg.function("FunctionName")(
                # Function Body
            )

        The function body is constructed via sequencing (see `make_sequence`).
        If no implementation file is being generated, the function is marked ``inline``.

        Args:
            name: Name of the function
            return_type: Deprecated; use ``.returns()`` on the returned sequencer instead.
        """
        seq = SfgFunctionSequencer(self._cursor, name)

        if return_type is not None:
            warn(
                "The parameter `return_type` to `function()` is deprecated and will be removed by version 0.1. "
                "Use `.returns()` instead.",
                FutureWarning,
                stacklevel=2,
            )
            seq.returns(return_type)

        if self._ctx.impl_file is None:
            seq.inline()

        return seq

    def call(self, kernel_handle: SfgKernelHandle) -> SfgCallTreeNode:
        """Use inside a function body to directly call a kernel.

        When using `call`, the given kernel will simply be called as a function.
        To invoke a GPU kernel on a specified launch grid,
        use `gpu_invoke <SfgGpuComposer.gpu_invoke>` instead.

        Args:
            kernel_handle: Handle to a kernel previously added to some kernel namespace.
        """
        return SfgKernelCallNode(kernel_handle)

    def seq(self, *args: SequencerArg) -> SfgSequence:
        """Syntax sequencing. For details, see `make_sequence`"""
        return make_sequence(*args)

    def params(self, *args: AugExpr) -> SfgFunctionParams:
        """Use inside a function body to add parameters to the function."""
        return SfgFunctionParams([x.as_variable() for x in args])

    def require(self, *incls: str | HeaderFile) -> SfgRequireIncludes:
        """Use inside a function body to require the inclusion of headers."""
        return SfgRequireIncludes((HeaderFile.parse(incl) for incl in incls))

    def var(self, name: str, dtype: UserTypeSpec) -> AugExpr:
        """Create a variable with given name and data type."""
        return AugExpr(create_type(dtype)).var(name)

    def vars(self, names: str, dtype: UserTypeSpec) -> tuple[AugExpr, ...]:
        """Create multiple variables with given names and the same data type.

        Example:

        >>> sfg.vars("x, y, z", "float32")
        (x, y, z)

        """
        varnames = names.split(",")
        return tuple(self.var(n.strip(), dtype) for n in varnames)

    def init(self, lhs: VarLike):
        """Create a C++ in-place initialization.

        Usage:

        .. code-block:: Python

            obj = sfg.var("obj", "SomeClass")
            sfg.init(obj)(arg1, arg2, arg3)

        becomes

        .. code-block:: C++

            SomeClass obj { arg1, arg2, arg3 };
        """
        lhs_var = asvar(lhs)

        def parse_args(*args: ExprLike):
            args_str = ", ".join(str(arg) for arg in args)
            # Both reductions start from an empty set, so that an empty
            # argument list (`sfg.init(obj)()`) does not make `reduce` fail.
            deps: set[SfgVar] = reduce(set.union, (depends(arg) for arg in args), set())
            incls: set[HeaderFile] = reduce(
                set.union, (includes(arg) for arg in args), set()
            )
            return SfgStatements(
                # `{{` / `}}` are literal braces; `{ {x} }` would format a set literal
                f"{lhs_var.dtype.c_string()} {lhs_var.name} {{ {args_str} }};",
                (lhs_var,),
                deps,
                incls,
            )

        return parse_args

    def expr(self, fmt: str, *deps, **kwdeps) -> AugExpr:
        """Create an expression while keeping track of variables it depends on.

        This method is meant to be used similarly to `str.format`; in fact,
        it calls `str.format` internally and therefore supports all of its
        formatting features.
        In addition, however, the format arguments are scanned for *variables*
        (e.g. created using `var`), which are attached to the expression.
        This way, *pystencils-sfg* keeps track of any variables an expression depends on.

        :Example:

            >>> x, y, z, w = sfg.vars("x, y, z, w", "float32")
            >>> expr = sfg.expr("{} + {} * {}", x, y, z)
            >>> expr
            x + y * z

            You can look at the expression's dependencies:

            >>> sorted(expr.depends, key=lambda v: v.name)
            [x: float32, y: float32, z: float32]

            If you use an existing expression to create a larger one, the new expression
            inherits all variables from its parts:

            >>> expr2 = sfg.expr("{} + {}", expr, w)
            >>> expr2
            x + y * z + w
            >>> sorted(expr2.depends, key=lambda v: v.name)
            [w: float32, x: float32, y: float32, z: float32]

        """
        return AugExpr.format(fmt, *deps, **kwdeps)

    def expr_from_lambda(self, lamb: Lambda) -> AugExpr:
        """Create an expression from a pystencils `Lambda`.

        The lambda's C code becomes the expression's code, its parameters
        (wrapped as `SfgKernelParamVar`) become the expression's dependencies,
        and its return type becomes the expression's data type.
        """
        # Named `deps` to avoid shadowing the module-level `depends` function
        deps = {SfgKernelParamVar(p) for p in lamb.parameters}
        code = lamb.c_code()
        return AugExpr.make(code, deps, dtype=lamb.return_type)

    @property
    def branch(self) -> SfgBranchBuilder:
        """Use inside a function body to create an if/else conditional branch.

        The syntax is:

        .. code-block:: Python

            sfg.branch("condition")(
                # then-body
            )(
                # else-body (may be omitted)
            )
        """
        return SfgBranchBuilder()

    def switch(self, switch_arg: ExprLike, autobreak: bool = True) -> SfgSwitchBuilder:
        """Use inside a function to construct a switch-case statement.

        Args:
            switch_arg: Argument to the `switch()` statement
            autobreak: Whether to automatically print a ``break;`` at the end of each case block
        """
        return SfgSwitchBuilder(switch_arg, autobreak=autobreak)

    def map_field(
        self,
        field: Field,
        index_provider: SupportsFieldExtraction,
        cast_indexing_symbols: bool = True,
    ) -> SfgDeferredFieldMapping:
        """Map a pystencils field to a field data structure, from which pointers, sizes
        and strides should be extracted.

        Args:
            field: The pystencils field to be mapped
            index_provider: An object that provides the field indexing information
            cast_indexing_symbols: Whether to always introduce explicit casts for indexing symbols
        """
        return SfgDeferredFieldMapping(
            field, index_provider, cast_indexing_symbols=cast_indexing_symbols
        )

    def set_param(self, param: VarLike | sp.Symbol, expr: ExprLike):
        """Set a kernel parameter to an expression.

        Code setting the parameter will only be generated if the parameter
        is actually alive (i.e. required by some kernel, and not yet set) at
        the point this method is called.
        """
        var: SfgVar | sp.Symbol = asvar(param) if isinstance(param, _VarLike) else param
        return SfgDeferredParamSetter(var, expr)

    def map_vector(
        self,
        lhs_components: Sequence[VarLike | sp.Symbol],
        rhs: SupportsVectorExtraction,
    ):
        """Extracts scalar numerical values from a vector data type.

        Args:
            lhs_components: Vector components as a list of symbols.
            rhs: An object providing access to vector components
        """
        components: list[SfgVar | sp.Symbol] = [
            (asvar(c) if isinstance(c, _VarLike) else c) for c in lhs_components
        ]
        return SfgDeferredVectorMapping(components, rhs)

567 

568 

def make_statements(arg: ExprLike) -> SfgStatements:
    """Wrap an expression-like object in an `SfgStatements` node.

    The node's code is ``str(arg)``, and it carries the dependencies and
    required includes of ``arg``. The second positional argument
    (presumably the declared variables) is left empty.
    """
    code = str(arg)
    return SfgStatements(code, (), depends(arg), includes(arg))

571 

572 

def make_sequence(*args: SequencerArg) -> SfgSequence:
    """Construct a sequence of C++ code from various kinds of arguments.

    `make_sequence` is ubiquitous throughout the function building front-end;
    among others, it powers the syntax of `SfgBasicComposer.function`
    and `SfgBasicComposer.branch`.

    `make_sequence` constructs an abstract syntax tree for code within a function body, accepting various
    types of arguments which then get turned into C++ code. These are

    - Strings (`str`) and other expression-like objects, which are printed via `str()`
    - Tuples (`tuple`) signify *blocks*, i.e. C++ code regions enclosed in ``{ }``
    - Sub-ASTs and AST builders, which are often produced by the syntactic sugar and
      factory methods of `SfgComposer`.

    :Example:

        .. code-block:: Python

            sfg.function("myFunction")(
                "int a = 0;",
                "int b = 1;",
                (
                    "int tmp = b;",
                    "b = a;",
                    "a = tmp;"
                ),
                sfg.call(kernel_handle)
            )

        Here, the function body is assembled by `make_sequence`, and will translate to

        .. code-block:: C++

            void myFunction() {
                int a = 0;
                int b = 1;
                {
                    int tmp = b;
                    b = a;
                    a = tmp;
                }
                kernels::kernel( ... );
            }

    Raises:
        TypeError: If an argument is of none of the supported kinds.
    """
    children: list[SfgCallTreeNode] = []
    for i, arg in enumerate(args):
        if isinstance(arg, SfgNodeBuilder):
            children.append(arg.resolve())
        elif isinstance(arg, SfgCallTreeNode):
            children.append(arg)
        elif isinstance(arg, _ExprLike):
            children.append(make_statements(arg))
        elif isinstance(arg, tuple):
            # Tuples are treated as blocks
            children.append(SfgBlock(make_sequence(*arg)))
        else:
            raise TypeError(
                f"Sequence argument {i} has invalid type: {type(arg).__name__}"
            )

    return SfgSequence(children)

636 

637 

class SfgFunctionSequencerBase:
    """Shared base for builders of functions and methods.

    Properties of the function or method are configured by chaining calls;
    every configuration method hands back the builder itself.

    Example:

    >>> sfg.function(
    ...     "myFunction"
    ... ).returns(
    ...     "float32"
    ... ).attr(
    ...     "nodiscard", "maybe_unused"
    ... ).inline().constexpr()(
    ...     "return 31.2;"
    ... )
    """

    def __init__(self, cursor: SfgCursor, name: str) -> None:
        self._cursor = cursor
        self._name = name

        # C++ attributes such as ``nodiscard``, in the order they were added
        self._attributes: list[str] = []

        # Qualifier flags, toggled by `inline` and `constexpr`
        self._constexpr: bool = False
        self._inline: bool = False

        # Signature: return type defaults to `void`; the explicit parameter
        # list stays ``None`` unless `params` is called.
        self._return_type: PsType = void
        self._params: list[SfgVar] | None = None

    def returns(self, rtype: UserTypeSpec):
        """Declare the return type; ``rtype`` is converted via `create_type`."""
        self._return_type = create_type(rtype)
        return self

    def params(self, *args: VarLike):
        """Fix the parameter list of this function explicitly.

        Any free variable collected from the function body that is missing
        from this list causes an error to be raised.
        """
        self._params = list(map(asvar, args))
        return self

    def inline(self):
        """Add the ``inline`` qualifier."""
        self._inline = True
        return self

    def constexpr(self):
        """Add the ``constexpr`` qualifier."""
        self._constexpr = True
        return self

    def attr(self, *attrs: str):
        """Append one or more attributes to this function."""
        self._attributes.extend(attrs)
        return self

699 

700 

class SfgFunctionSequencer(SfgFunctionSequencerBase):
    """Sequencer for free functions; calling it supplies the body and emits the function."""

    def __call__(self, *args: SequencerArg) -> None:
        """Populate the function body and register the function.

        The arguments are assembled into the body via `make_sequence`, and the
        function is added as an entity in the cursor's current namespace.
        An inline function is defined directly in the header file; otherwise,
        the header receives a declaration and the implementation file the definition.
        """
        body = make_sequence(*args)
        cursor = self._cursor
        func = SfgFunction(
            self._name,
            cursor.current_namespace,
            body,
            return_type=self._return_type,
            inline=self._inline,
            constexpr=self._constexpr,
            attributes=self._attributes,
            required_params=self._params,
        )
        cursor.add_entity(func)

        if not self._inline:
            cursor.write_header(SfgEntityDecl(func))
            cursor.write_impl(SfgEntityDef(func))
            return

        cursor.write_header(SfgEntityDef(func))

724 

725 

726class SfgBranchBuilder(SfgNodeBuilder): 

727 """Multi-call builder for C++ ``if/else`` statements.""" 

728 

729 def __init__(self) -> None: 

730 self._phase = 0 

731 

732 self._cond: ExprLike | None = None 

733 self._branch_true = SfgSequence(()) 

734 self._branch_false: SfgSequence | None = None 

735 

736 def __call__(self, *args) -> SfgBranchBuilder: 

737 match self._phase: 

738 case 0: # Condition 

739 if len(args) != 1: 

740 raise ValueError( 

741 "Must specify exactly one argument as branch condition!" 

742 ) 

743 

744 self._cond = args[0] 

745 

746 case 1: # Then-branch 

747 self._branch_true = make_sequence(*args) 

748 case 2: # Else-branch 

749 self._branch_false = make_sequence(*args) 

750 case _: # There's no third branch! 

751 raise TypeError("Branch construct already complete.") 

752 

753 self._phase += 1 

754 

755 return self 

756 

757 def resolve(self) -> SfgCallTreeNode: 

758 assert self._cond is not None 

759 return SfgBranch( 

760 make_statements(self._cond), self._branch_true, self._branch_false 

761 ) 

762 

763 

class SfgSwitchBuilder(SfgNodeBuilder):
    """Builder for C++ ``switch`` statements.

    Case bodies are registered through `case`, `cases` and `default`;
    `resolve` assembles the final switch node.
    """

    def __init__(self, switch_arg: ExprLike, autobreak: bool = True):
        self._switch_arg = switch_arg
        self._autobreak = autobreak
        # Case label -> case body, in insertion order
        self._cases: dict[str, SfgSequence] = {}
        self._default: SfgSequence | None = None

    def case(self, label: str):
        """Open a ``case`` with the given label.

        Returns a sequencer function. Calling it with the case body registers
        the case and returns this builder. If ``autobreak`` is enabled,
        ``break;`` is appended to the body.

        Raises:
            SfgException: If a case with this label was already registered.
        """
        if label in self._cases:
            raise SfgException(f"Duplicate case: {label}")

        def register_body(*body: SequencerArg):
            if self._autobreak:
                body = body + ("break;",)
            self._cases[label] = make_sequence(*body)
            return self

        return register_body

    def cases(self, cases_dict: dict[str, SequencerArg]):
        """Register several cases at once, mapping each label to its body."""
        for label, body in cases_dict.items():
            self.case(label)(body)
        return self

    def default(self, *args):
        """Set the body of the ``default`` case.

        Raises:
            SfgException: If a default case was already set.
        """
        if self._default is not None:
            raise SfgException("Duplicate default case")

        self._default = make_sequence(*args)
        return self

    def resolve(self) -> SfgCallTreeNode:
        """Build the switch node from the argument, the cases, and the default case."""
        switch_stmt = make_statements(self._switch_arg)
        return SfgSwitch(switch_stmt, self._cases, self._default)