Coverage for src/pystencilssfg/ir/call_tree.py: 84%
230 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from __future__ import annotations
2from typing import TYPE_CHECKING, Sequence, Iterable, NewType
4from abc import ABC, abstractmethod
6from .entities import SfgKernelHandle
7from ..lang import SfgVar, HeaderFile
9if TYPE_CHECKING:
10 from ..config import CodeStyle
class SfgCallTreeNode(ABC):
    """Abstract base of every node that makes up an SFG call tree.

    ## Code Printing

    To keep code printing extensible, it lives inside the call tree itself:
    each concrete node type has to provide `get_code`.
    As a convention, the string produced by `get_code` does not end in a newline.
    """

    def __init__(self) -> None:
        # header files required by this node; subclasses may replace this set
        self._includes: set[HeaderFile] = set()

    @property
    def required_includes(self) -> set[HeaderFile]:
        """Header files that must be included for this node's code"""
        return self._includes

    @property
    def depends(self) -> set[SfgVar]:
        """Objects this node depends on; the base implementation reports none"""
        return set()

    @property
    @abstractmethod
    def children(self) -> Sequence[SfgCallTreeNode]:
        """The direct child nodes of this node"""

    @abstractmethod
    def get_code(self, cstyle: CodeStyle) -> str:
        """Render this node as source code.

        As a convention, the emitted code block carries no trailing newline.
        """
class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
    """Base class for call tree nodes that have no children.

    Subclasses are expected to implement ``depends`` so that parameters
    can be collected automatically.
    """

    def __init__(self):
        super().__init__()

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        # a leaf never has child nodes
        return tuple()
class SfgEmptyNode(SfgCallTreeLeaf):
    """Leaf node whose printed code is the empty string.

    Although they print nothing, empty nodes are still expected to implement ``depends``.
    """

    def __init__(self):
        super().__init__()

    def get_code(self, cstyle: CodeStyle) -> str:
        # intentionally prints nothing
        empty = ""
        return empty
class SfgStatements(SfgCallTreeLeaf):
    """Represents (a sequence of) statements in the source language.

    This class groups together arbitrary code strings
    (e.g. sequences of C++ statements, cf. https://en.cppreference.com/w/cpp/language/statements),
    and annotates them with the set of symbols read and written by these statements.

    It is the user's responsibility to ensure that the code string is valid code in the output language,
    and that the sets of defined and required variables are correct and complete.

    Args:
        code_string: Code to be printed out.
        defines: Variables that will be newly defined and visible to code in sequence after these statements.
        depends: Variables that are required as input to these statements.
        includes: Header files required by these statements.
    """

    def __init__(
        self,
        code_string: str,
        defines: Iterable[SfgVar],
        depends: Iterable[SfgVar],
        includes: Iterable[HeaderFile] = (),
    ):
        super().__init__()

        self._code_string = code_string

        # store as sets: order is irrelevant and duplicates are collapsed
        self._defines = set(defines)
        self._depends = set(depends)
        self._includes = set(includes)

    @property
    def depends(self) -> set[SfgVar]:
        """Variables read by these statements"""
        return self._depends

    @property
    def defines(self) -> set[SfgVar]:
        """Variables newly defined by these statements"""
        return self._defines

    @property
    def code_string(self) -> str:
        """The raw code string, exactly as passed to the constructor"""
        return self._code_string

    def get_code(self, cstyle: CodeStyle) -> str:
        # the code string is emitted verbatim; the code style is not applied
        return self._code_string
class SfgFunctionParams(SfgEmptyNode):
    """Empty node that reports the given variables as its dependencies.

    It prints no code; the variables are exposed only through ``depends``.

    Args:
        parameters: The variables to report.
    """

    def __init__(self, parameters: Sequence[SfgVar]):
        super().__init__()
        self._params = {*parameters}

    @property
    def depends(self) -> set[SfgVar]:
        return self._params
class SfgRequireIncludes(SfgEmptyNode):
    """Empty node that contributes header includes without printing any code.

    Args:
        includes: Header files reported through ``required_includes``.
    """

    def __init__(self, includes: Iterable[HeaderFile]):
        super().__init__()
        self._includes = {*includes}

    @property
    def depends(self) -> set[SfgVar]:
        # this node has no variable dependencies
        return set()
class SfgSequence(SfgCallTreeNode):
    """An ordered sequence of call tree nodes whose code is printed line after line.

    The children are held in a mutable list. They can be replaced all at once via
    ``children`` or one at a time via item assignment.
    """

    __match_args__ = ("children",)

    def __init__(self, children: Sequence[SfgCallTreeNode]):
        super().__init__()
        self._children = [*children]

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        # returns the internal list itself, not a copy
        return self._children

    @children.setter
    def children(self, cs: Sequence[SfgCallTreeNode]):
        self._children = [*cs]

    def __getitem__(self, idx: int) -> SfgCallTreeNode:
        return self._children[idx]

    def __setitem__(self, idx: int, c: SfgCallTreeNode):
        self._children[idx] = c

    def get_code(self, cstyle: CodeStyle) -> str:
        pieces = [child.get_code(cstyle) for child in self._children]
        return "\n".join(pieces)
class SfgBlock(SfgCallTreeNode):
    """A braced block that wraps a single sequence.

    The sequence's code is indented according to the given code style.

    Args:
        seq: The sequence forming the block body.
    """

    def __init__(self, seq: SfgSequence):
        super().__init__()
        self._seq = seq

    @property
    def sequence(self) -> SfgSequence:
        return self._seq

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        return (self._seq,)

    def get_code(self, cstyle: CodeStyle) -> str:
        body = cstyle.indent(self._seq.get_code(cstyle))
        return "\n".join(["{", body, "}"])
187# class SfgForLoop(SfgCallTreeNode):
188# def __init__(self, control_line: SfgStatements, body: SfgCallTreeNode):
189# super().__init__(control_line, body)
191# @property
192# def body(self) -> SfgStatements:
193# return cast(SfgStatements)
class SfgKernelCallNode(SfgCallTreeLeaf):
    """Leaf node that calls a kernel function, passing each kernel parameter by name.

    Args:
        kernel_handle: Handle of the kernel to call.
    """

    def __init__(self, kernel_handle: SfgKernelHandle):
        super().__init__()
        self._kernel_handle = kernel_handle

    @property
    def depends(self) -> set[SfgVar]:
        return {*self._kernel_handle.parameters}

    def get_code(self, cstyle: CodeStyle) -> str:
        handle = self._kernel_handle
        params = handle.parameters
        name = handle.fqname
        args = ", ".join(param.name for param in params)
        return f"{name}({args});"
class SfgGpuKernelInvocation(SfgCallTreeNode):
    """A CUDA or HIP kernel invocation.

    See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#execution-configuration
    or https://rocmdocs.amd.com/projects/HIP/en/latest/how-to/hip_cpp_language_extensions.html#calling-global-functions
    for the syntax.

    Args:
        kernel_handle: Handle of the kernel to launch; its kernel must be a ``GpuKernel``.
        grid_size: Code of the grid size launch argument.
        block_size: Code of the block size launch argument.
        shared_memory_bytes: Optional code of the shared memory size launch argument.
        stream: Optional code of the stream launch argument.

    Raises:
        ValueError: If the handle's kernel is not a ``GpuKernel``.
    """

    def __init__(
        self,
        kernel_handle: SfgKernelHandle,
        grid_size: SfgStatements,
        block_size: SfgStatements,
        shared_memory_bytes: SfgStatements | None,
        stream: SfgStatements | None,
    ):
        from pystencils.codegen import GpuKernel

        kernel = kernel_handle.kernel
        if not isinstance(kernel, GpuKernel):
            raise ValueError(
                "An `SfgGpuKernelInvocation` node can only call GPU kernels."
            )

        super().__init__()
        self._kernel_handle = kernel_handle
        self._grid_size = grid_size
        self._block_size = block_size
        self._shared_memory_bytes = shared_memory_bytes
        self._stream = stream

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        # only the launch arguments that were actually given are children
        return (
            (
                self._grid_size,
                self._block_size,
            )
            + (
                (self._shared_memory_bytes,)
                if self._shared_memory_bytes is not None
                else ()
            )
            + ((self._stream,) if self._stream is not None else ())
        )

    @property
    def depends(self) -> set[SfgVar]:
        return set(self._kernel_handle.parameters)

    def get_code(self, cstyle: CodeStyle) -> str:
        kparams = self._kernel_handle.parameters
        fnc_name = self._kernel_handle.fqname
        call_parameters = ", ".join([p.name for p in kparams])

        launch_args = [
            self._grid_size.get_code(cstyle),
            self._block_size.get_code(cstyle),
        ]
        if self._shared_memory_bytes is not None:
            launch_args.append(self._shared_memory_bytes.get_code(cstyle))
        elif self._stream is not None:
            # The execution configuration is positional (grid, block, shared memory, stream).
            # A stream without a shared memory size would be taken as the shared memory size,
            # so an explicit size of zero has to be emitted in between.
            launch_args.append("0")

        if self._stream is not None:
            launch_args.append(self._stream.get_code(cstyle))

        grid = "<<< " + ", ".join(launch_args) + " >>>"
        return f"{fnc_name}{grid}({call_parameters});"
class SfgBranch(SfgCallTreeNode):
    """An ``if`` branch with an optional ``else`` clause.

    Args:
        cond: Statements whose code is printed as the branch condition.
        branch_true: Body of the ``if`` clause.
        branch_false: Optional body of the ``else`` clause.
    """

    def __init__(
        self,
        cond: SfgStatements,
        branch_true: SfgSequence,
        branch_false: SfgSequence | None = None,
    ):
        super().__init__()
        self._cond = cond
        self._branch_true = branch_true
        self._branch_false = branch_false

    @property
    def condition(self) -> SfgStatements:
        return self._cond

    @property
    def branch_true(self) -> SfgSequence:
        return self._branch_true

    @property
    def branch_false(self) -> SfgSequence | None:
        return self._branch_false

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        return (
            self._cond,
            self._branch_true,
        ) + ((self.branch_false,) if self.branch_false is not None else ())

    def get_code(self, cstyle: CodeStyle) -> str:
        # `{{` is required to emit a literal opening brace from an f-string
        code = f"if({self.condition.get_code(cstyle)}) {{\n"
        code += cstyle.indent(self.branch_true.get_code(cstyle))
        code += "\n}"

        if self.branch_false is not None:
            code += "else {\n"
            code += cstyle.indent(self.branch_false.get_code(cstyle))
            code += "\n}"

        return code
class SfgSwitchCase(SfgCallTreeNode):
    """A single ``case`` (or the ``default`` case) of a ``switch`` statement.

    Args:
        label: Code string printed as the case label, or ``SfgSwitchCase.Default``
            to create the ``default`` case.
        body: Body of this case.
    """

    DefaultCaseType = NewType("DefaultCaseType", object)
    """Sentinel type representing the ``default`` case."""

    # The single sentinel instance marking the default case; it is compared by identity.
    Default = DefaultCaseType(object())

    def __init__(self, label: str | SfgSwitchCase.DefaultCaseType, body: SfgSequence):
        super().__init__()
        self._label = label
        self._body = body

    @property
    def label(self) -> str | DefaultCaseType:
        return self._label

    @property
    def body(self) -> SfgSequence:
        return self._body

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        return (self._body,)

    @property
    def is_default(self) -> bool:
        # identity check against the unique sentinel object
        return self._label is SfgSwitchCase.Default

    def get_code(self, cstyle: CodeStyle) -> str:
        if self.is_default:
            code = "default: {\n"
        else:
            # `{{` is required to emit a literal opening brace from an f-string
            code = f"case {self._label}: {{\n"
        code += cstyle.indent(self.body.get_code(cstyle))
        code += "\n}"
        return code
class SfgSwitch(SfgCallTreeNode):
    """A ``switch`` statement with a list of cases and an optional ``default`` case.

    Invariant: all cases are stored in ``self._cases``. If a default case exists,
    it is the last entry of ``self._cases``, and ``self._default`` refers to that
    very same node.

    Args:
        switch_arg: Statements whose code is printed as the switch argument.
        cases_dict: Mapping from case label code to case body; cases are created
            in the dictionary's iteration order.
        default: Optional body of the ``default`` case.
    """

    def __init__(
        self,
        switch_arg: SfgStatements,
        cases_dict: dict[str, SfgSequence],
        default: SfgSequence | None = None,
    ):
        super().__init__()
        self._switch_arg = switch_arg
        self._cases: list[SfgSwitchCase] = [
            SfgSwitchCase(label, body) for label, body in cases_dict.items()
        ]
        self._default: SfgSwitchCase | None = None
        if default is not None:
            # the same node object is stored in `_cases` and `_default`
            self._default = SfgSwitchCase(SfgSwitchCase.Default, default)
            self._cases.append(self._default)

    @property
    def switch_arg(self) -> SfgStatements:
        return self._switch_arg

    @property
    def default(self) -> SfgSwitchCase | None:
        return self._default

    @property
    def children(self) -> tuple[SfgCallTreeNode, ...]:
        return (self._switch_arg,) + tuple(self._cases)

    @property
    def cases(self) -> tuple[SfgCallTreeNode, ...]:
        """The non-default cases, in order."""
        if self._default is not None:
            return tuple(self._cases[:-1])
        return tuple(self._cases)

    @cases.setter
    def cases(self, cs: Sequence[SfgSwitchCase]) -> None:
        """Replace all cases.

        ``cs`` must have as many entries as currently stored, *including* the
        default case if one exists. A default case, if given, must come last.

        NOTE(review): unlike the getter, this expects the default case to be part
        of ``cs`` — confirm this asymmetry is intended.

        Raises:
            ValueError: If the number of cases changes or a default case is not last.
        """
        new_cases = list(cs)
        if len(new_cases) != len(self._cases):
            raise ValueError("The number of child nodes must remain the same!")

        # validate fully before mutating, so a failed assignment leaves the node intact
        new_default = None
        for i, c in enumerate(new_cases):
            if c.is_default:
                if i != len(new_cases) - 1:
                    raise ValueError("Default case must be listed last.")
                new_default = c

        self._default = new_default
        self._cases = new_cases

    def set_case(self, idx: int, c: SfgSwitchCase):
        """Replace the case at position ``idx`` (indexing all cases, including the default).

        Negative indices count from the end.

        Raises:
            ValueError: If the replacement would move or remove the default case,
                or introduce a default case where there was none.
            IndexError: If ``idx`` is out of range.
        """
        last = len(self._cases) - 1
        if idx < 0:
            idx += len(self._cases)

        if c.is_default:
            if idx != last:
                raise ValueError("Default case must be the last child.")
            elif self._default is None:
                raise ValueError("Cannot replace normal case with default case.")
            else:
                self._default = c
                self._cases[-1] = c
        else:
            if idx == last and self._default is not None:
                # would otherwise leave `_default` pointing at a node no longer stored
                raise ValueError("Cannot replace default case with normal case.")
            self._cases[idx] = c

    def get_code(self, cstyle: CodeStyle) -> str:
        # `{{` is required to emit a literal opening brace from an f-string
        code = f"switch({self._switch_arg.get_code(cstyle)}) {{\n"
        code += "\n".join(c.get_code(cstyle) for c in self._cases)
        # closing brace on its own line, consistent with the other node types
        code += "\n}"
        return code