Coverage for src/pystencilssfg/ir/call_tree.py: 84%

230 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from __future__ import annotations 

2from typing import TYPE_CHECKING, Sequence, Iterable, NewType 

3 

4from abc import ABC, abstractmethod 

5 

6from .entities import SfgKernelHandle 

7from ..lang import SfgVar, HeaderFile 

8 

9if TYPE_CHECKING: 

10 from ..config import CodeStyle 

11 

12 

class SfgCallTreeNode(ABC):
    """Abstract base of every node in an SFG call tree.

    ## Code Printing

    Printing is implemented by the nodes themselves, so new node types can be
    added without touching a central printer. Each concrete node class must
    therefore provide `get_code`. By convention, the string it returns does
    not end in a newline.
    """

    def __init__(self) -> None:
        # Header files this node needs; subclasses may replace this set.
        self._includes: set[HeaderFile] = set()

    @property
    @abstractmethod
    def children(self) -> Sequence[SfgCallTreeNode]:
        """The direct child nodes of this node."""

    @abstractmethod
    def get_code(self, cstyle: CodeStyle) -> str:
        """Render this node to source code.

        The returned code must, by convention, not end in a trailing newline.
        """

    @property
    def depends(self) -> set[SfgVar]:
        """Variables this node requires; a fresh empty set unless overridden."""
        return set()

    @property
    def required_includes(self) -> set[HeaderFile]:
        """Header files that must be included for this node's code.

        Note that this returns the node's internal set itself, not a copy.
        """
        return self._includes

47 

48 

class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
    """Base class for leaf nodes, i.e. nodes without any children.

    Leaves are expected to override ``depends`` so that function parameters
    can be collected automatically.
    """

    def __init__(self):
        super().__init__()

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        # A leaf never has children.
        no_children: tuple[SfgCallTreeNode, ...] = ()
        return no_children

61 

62 

class SfgEmptyNode(SfgCallTreeLeaf):
    """Leaf node that contributes no code to the output.

    Subclasses use it to attach dependencies or includes to the call tree.
    They should still provide ``depends``.
    """

    def __init__(self):
        super().__init__()

    def get_code(self, cstyle: CodeStyle) -> str:
        # Nothing to print.
        empty = ""
        return empty

74 

75 

class SfgStatements(SfgCallTreeLeaf):
    """Represents (a sequence of) statements in the source language.

    This class groups together arbitrary code strings
    (e.g. sequences of C++ statements, cf. https://en.cppreference.com/w/cpp/language/statements),
    and annotates them with the set of symbols read and written by these statements.

    It is the user's responsibility to ensure that the code string is valid code in the output language,
    and that the sets of defined and required variables are correct and complete.

    Args:
        code_string: Code to be printed out verbatim.
        defines: Variables that will be newly defined and visible to code in sequence after these statements.
        depends: Variables that are required as input to these statements.
        includes: Header files required by the code in ``code_string``.
    """

    def __init__(
        self,
        code_string: str,
        defines: Iterable[SfgVar],
        depends: Iterable[SfgVar],
        includes: Iterable[HeaderFile] = (),
    ):
        super().__init__()

        self._code_string = code_string

        # Materialize the iterables into sets so they can be queried repeatedly
        # and duplicates are removed.
        self._defines = set(defines)
        self._depends = set(depends)
        # Replaces the empty include set created by the base class.
        self._includes = set(includes)

    @property
    def depends(self) -> set[SfgVar]:
        """Variables read by these statements (the internal set, not a copy)."""
        return self._depends

    @property
    def defines(self) -> set[SfgVar]:
        """Variables defined by these statements (the internal set, not a copy)."""
        return self._defines

    @property
    def code_string(self) -> str:
        """The raw code string passed to the constructor."""
        return self._code_string

    def get_code(self, cstyle: CodeStyle) -> str:
        # The code string is emitted exactly as given; `cstyle` is not applied.
        return self._code_string

121 

122 

class SfgFunctionParams(SfgEmptyNode):
    """Code-less node that declares a set of function parameters as dependencies."""

    def __init__(self, parameters: Sequence[SfgVar]):
        super().__init__()
        self._params = {param for param in parameters}

    @property
    def depends(self) -> set[SfgVar]:
        """The parameters given to the constructor, with duplicates removed."""
        return self._params

131 

132 

class SfgRequireIncludes(SfgEmptyNode):
    """Code-less node whose only purpose is to request header includes."""

    def __init__(self, includes: Iterable[HeaderFile]):
        super().__init__()
        # Overrides the empty include set created by the base class.
        self._includes = {header for header in includes}

    @property
    def depends(self) -> set[SfgVar]:
        # Requesting includes introduces no variable dependencies.
        return set()

141 

142 

class SfgSequence(SfgCallTreeNode):
    """An ordered list of child nodes whose code is emitted one after another."""

    __match_args__ = ("children",)

    def __init__(self, children: Sequence[SfgCallTreeNode]):
        super().__init__()
        self._children = [*children]

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        """The child nodes in emission order (the live internal list)."""
        return self._children

    @children.setter
    def children(self, cs: Sequence[SfgCallTreeNode]):
        self._children = [*cs]

    def __getitem__(self, idx: int) -> SfgCallTreeNode:
        return self._children[idx]

    def __setitem__(self, idx: int, c: SfgCallTreeNode):
        self._children[idx] = c

    def get_code(self, cstyle: CodeStyle) -> str:
        # Children are separated by a single newline; no trailing newline.
        parts = [child.get_code(cstyle) for child in self._children]
        return "\n".join(parts)

166 

167 

class SfgBlock(SfgCallTreeNode):
    """A sequence wrapped in a pair of braces, with its contents indented."""

    def __init__(self, seq: SfgSequence):
        super().__init__()
        self._seq = seq

    @property
    def sequence(self) -> SfgSequence:
        """The wrapped sequence."""
        return self._seq

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        return (self._seq,)

    def get_code(self, cstyle: CodeStyle) -> str:
        body = cstyle.indent(self._seq.get_code(cstyle))
        return "\n".join(("{", body, "}"))

185 

186 

187# class SfgForLoop(SfgCallTreeNode): 

188# def __init__(self, control_line: SfgStatements, body: SfgCallTreeNode): 

189# super().__init__(control_line, body) 

190 

191# @property 

192# def body(self) -> SfgStatements: 

193# return cast(SfgStatements) 

194 

195 

class SfgKernelCallNode(SfgCallTreeLeaf):
    """Leaf node emitting a plain function call to a kernel.

    The kernel is called by its fully qualified name, and each kernel
    parameter is passed by its name.
    """

    def __init__(self, kernel_handle: SfgKernelHandle):
        super().__init__()
        self._kernel_handle = kernel_handle

    @property
    def depends(self) -> set[SfgVar]:
        # The call needs every kernel parameter to be available.
        return set(self._kernel_handle.parameters)

    def get_code(self, cstyle: CodeStyle) -> str:
        handle = self._kernel_handle
        args = ", ".join(param.name for param in handle.parameters)
        return f"{handle.fqname}({args});"

211 

212 

class SfgGpuKernelInvocation(SfgCallTreeNode):
    """A CUDA or HIP kernel invocation.

    See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#execution-configuration
    or https://rocmdocs.amd.com/projects/HIP/en/latest/how-to/hip_cpp_language_extensions.html#calling-global-functions
    for the syntax.

    The emitted code has the form
    ``kernel<<< grid_size, block_size[, shared_memory_bytes[, stream]] >>>(params...);``.
    The execution-configuration arguments are positional. If a stream is given
    without a shared memory size, ``0`` is emitted for the shared memory size.
    This keeps the stream from being taken as the third argument.

    Args:
        kernel_handle: Handle of the GPU kernel to invoke.
        grid_size: Expression for the grid dimensions.
        block_size: Expression for the block dimensions.
        shared_memory_bytes: Optional expression for the dynamic shared memory size in bytes.
        stream: Optional expression for the stream to launch on.

    Raises:
        ValueError: If the handle's kernel is not a ``pystencils.codegen.GpuKernel``.
    """

    def __init__(
        self,
        kernel_handle: SfgKernelHandle,
        grid_size: SfgStatements,
        block_size: SfgStatements,
        shared_memory_bytes: SfgStatements | None,
        stream: SfgStatements | None,
    ):
        # Imported locally, as in the original module layout.
        from pystencils.codegen import GpuKernel

        kernel = kernel_handle.kernel
        if not isinstance(kernel, GpuKernel):
            raise ValueError(
                "An `SfgGpuKernelInvocation` node can only call GPU kernels."
            )

        super().__init__()
        self._kernel_handle = kernel_handle
        self._grid_size = grid_size
        self._block_size = block_size
        self._shared_memory_bytes = shared_memory_bytes
        self._stream = stream

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        # Optional configuration arguments are only children when present.
        optional = tuple(
            s for s in (self._shared_memory_bytes, self._stream) if s is not None
        )
        return (self._grid_size, self._block_size) + optional

    @property
    def depends(self) -> set[SfgVar]:
        return set(self._kernel_handle.parameters)

    def get_code(self, cstyle: CodeStyle) -> str:
        kparams = self._kernel_handle.parameters
        fnc_name = self._kernel_handle.fqname
        call_parameters = ", ".join(p.name for p in kparams)

        config = [
            self._grid_size.get_code(cstyle),
            self._block_size.get_code(cstyle),
        ]
        if self._shared_memory_bytes is not None:
            config.append(self._shared_memory_bytes.get_code(cstyle))
        elif self._stream is not None:
            # The stream may only appear in the fourth position, so the
            # shared-memory size must be spelled out explicitly.
            config.append("0")

        if self._stream is not None:
            config.append(self._stream.get_code(cstyle))

        grid = "<<< " + ", ".join(config) + " >>>"
        return f"{fnc_name}{grid}({call_parameters});"

277 

278 

class SfgBranch(SfgCallTreeNode):
    """An ``if`` statement with a true-branch and an optional ``else`` branch.

    Args:
        cond: The condition, printed inside ``if(...)``.
        branch_true: Sequence emitted in the ``if`` body.
        branch_false: Optional sequence emitted in the ``else`` body.
    """

    def __init__(
        self,
        cond: SfgStatements,
        branch_true: SfgSequence,
        branch_false: SfgSequence | None = None,
    ):
        super().__init__()
        self._cond = cond
        self._branch_true = branch_true
        self._branch_false = branch_false

    @property
    def condition(self) -> SfgStatements:
        return self._cond

    @property
    def branch_true(self) -> SfgSequence:
        return self._branch_true

    @property
    def branch_false(self) -> SfgSequence | None:
        return self._branch_false

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        return (
            self._cond,
            self._branch_true,
        ) + ((self.branch_false,) if self.branch_false is not None else ())

    def get_code(self, cstyle: CodeStyle) -> str:
        # `{{` is an escaped literal brace. A single `{` would open an
        # f-string replacement field and is a syntax error.
        code = f"if({self.condition.get_code(cstyle)}) {{\n"
        code += cstyle.indent(self.branch_true.get_code(cstyle))
        code += "\n}"

        if self.branch_false is not None:
            code += "else {\n"
            code += cstyle.indent(self.branch_false.get_code(cstyle))
            code += "\n}"

        return code

321 

322 

class SfgSwitchCase(SfgCallTreeNode):
    """A single ``case`` (or ``default``) label of a ``switch`` statement, with its body.

    The body is emitted, indented, inside its own pair of braces.
    """

    DefaultCaseType = NewType("DefaultCaseType", object)
    """Sentinel type representing the ``default`` case."""

    # Unique sentinel object used as the label of the ``default`` case.
    # It is compared by identity.
    Default = DefaultCaseType(object())

    def __init__(self, label: str | SfgSwitchCase.DefaultCaseType, body: SfgSequence):
        super().__init__()
        self._label = label
        self._body = body

    @property
    def label(self) -> str | DefaultCaseType:
        return self._label

    @property
    def body(self) -> SfgSequence:
        return self._body

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        return (self._body,)

    @property
    def is_default(self) -> bool:
        """Whether this is the ``default`` case."""
        # The sentinel is a unique object: test identity, not equality.
        return self._label is SfgSwitchCase.Default

    def get_code(self, cstyle: CodeStyle) -> str:
        # NOTE(review): no ``break;`` is emitted. Unless the body ends in a jump
        # statement, C++ control falls through into the next case. Confirm
        # this is intended.
        if self.is_default:
            code = "default: {\n"
        else:
            # `{{` is an escaped literal brace inside the f-string.
            code = f"case {self._label}: {{\n"
        code += cstyle.indent(self.body.get_code(cstyle))
        code += "\n}"
        return code

359 

360 

class SfgSwitch(SfgCallTreeNode):
    """A ``switch`` statement over a set of labelled cases and an optional default.

    Invariant: all cases are stored in ``self._cases``. If a default case
    exists, it is the last element of that list, and ``self._default`` refers
    to that very same object.

    Args:
        switch_arg: The expression switched on, printed inside ``switch(...)``.
        cases_dict: Mapping from case label to case body, in emission order.
        default: Optional body of the ``default`` case.
    """

    def __init__(
        self,
        switch_arg: SfgStatements,
        cases_dict: dict[str, SfgSequence],
        default: SfgSequence | None = None,
    ):
        super().__init__()
        self._switch_arg = switch_arg
        self._cases: list[SfgSwitchCase] = [
            SfgSwitchCase(label, body) for label, body in cases_dict.items()
        ]
        self._default: SfgSwitchCase | None = None
        if default is not None:
            # invariant: the default case is always the last child, and
            # `_default` is the same object that gets printed.
            self._default = SfgSwitchCase(SfgSwitchCase.Default, default)
            self._cases.append(self._default)

    @property
    def switch_arg(self) -> SfgStatements:
        return self._switch_arg

    @property
    def default(self) -> SfgSwitchCase | None:
        """The default case, or ``None`` if there is none."""
        return self._default

    @property
    def children(self) -> tuple[SfgCallTreeNode, ...]:
        return (self._switch_arg,) + tuple(self._cases)

    @property
    def cases(self) -> tuple[SfgSwitchCase, ...]:
        """The non-default cases, in order. The default case is excluded."""
        if self._default is not None:
            return tuple(self._cases[:-1])
        return tuple(self._cases)

    @cases.setter
    def cases(self, cs: Sequence[SfgSwitchCase]) -> None:
        """Replace the cases.

        ``cs`` may either list all cases, with any default case last, or, if a
        default case exists, only the non-default cases (as returned by the
        getter). In the latter case the existing default case is kept.

        Raises:
            ValueError: If the number of cases changes, or a default case is
                not listed last.
        """
        if self._default is not None and len(cs) == len(self._cases) - 1:
            # Called with the getter's view: keep the existing default case.
            new_cases = list(cs) + [self._default]
        elif len(cs) == len(self._cases):
            new_cases = list(cs)
        else:
            raise ValueError("The number of child nodes must remain the same!")

        # Validate everything before mutating, so a failed assignment leaves
        # the node unchanged.
        new_default: SfgSwitchCase | None = None
        for i, c in enumerate(new_cases):
            if c.is_default:
                if i != len(new_cases) - 1:
                    raise ValueError("Default case must be listed last.")
                new_default = c

        self._cases = new_cases
        self._default = new_default

    def set_case(self, idx: int, c: SfgSwitchCase):
        """Replace the case at position ``idx`` of the full case list (default included).

        Raises:
            ValueError: If a default case is placed anywhere but last, or
                would replace a regular case.
            IndexError: If ``idx`` is out of range.
        """
        if idx < 0:
            idx += len(self._cases)
        is_last = idx == len(self._cases) - 1

        if c.is_default:
            if not is_last:
                raise ValueError("Default case must be the last child.")
            elif self._default is None:
                raise ValueError("Cannot replace normal case with default case.")
            else:
                self._default = c
                self._cases[-1] = c
        else:
            self._cases[idx] = c
            if is_last and self._default is not None:
                # The former default case was replaced by a regular case.
                self._default = None

    def get_code(self, cstyle: CodeStyle) -> str:
        # `{{` is an escaped literal brace inside the f-string.
        code = f"switch({self._switch_arg.get_code(cstyle)}) {{\n"
        code += "\n".join(c.get_code(cstyle) for c in self._cases)
        # Put the closing brace on its own line instead of gluing it to the
        # last case's closing brace.
        code += "\n}"
        return code