-
Frederik Hennig authoredFrederik Hennig authored
call_tree.py 12.36 KiB
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence, Iterable, NewType
from abc import ABC, abstractmethod
from .entities import SfgKernelHandle
from ..lang import SfgVar, HeaderFile
if TYPE_CHECKING:
from ..config import CodeStyle
class SfgCallTreeNode(ABC):
"""Base class for all nodes comprising SFG call trees.
## Code Printing
For extensibility, code printing is implemented inside the call tree.
Therefore, every instantiable call tree node must implement the method `get_code`.
By convention, the string returned by `get_code` should not contain a trailing newline.
"""
def __init__(self) -> None:
self._includes: set[HeaderFile] = set()
@property
@abstractmethod
def children(self) -> Sequence[SfgCallTreeNode]:
"""This node's children"""
@abstractmethod
def get_code(self, cstyle: CodeStyle) -> str:
"""Returns the code of this node.
By convention, the code block emitted by this function should not contain a trailing newline.
"""
@property
def depends(self) -> set[SfgVar]:
"""Set of objects this leaf depends on"""
return set()
@property
def required_includes(self) -> set[HeaderFile]:
"""Return a set of header includes required by this node"""
return self._includes
class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
"""A leaf node of the call tree.
Leaf nodes must implement ``depends`` for automatic parameter collection.
"""
def __init__(self):
super().__init__()
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return ()
class SfgEmptyNode(SfgCallTreeLeaf):
"""A leaf node that does not emit any code.
Empty nodes must still implement ``depends``.
"""
def __init__(self):
super().__init__()
def get_code(self, cstyle: CodeStyle) -> str:
return ""
class SfgStatements(SfgCallTreeLeaf):
"""Represents (a sequence of) statements in the source language.
This class groups together arbitrary code strings
(e.g. sequences of C++ statements, cf. https://en.cppreference.com/w/cpp/language/statements),
and annotates them with the set of symbols read and written by these statements.
It is the user's responsibility to ensure that the code string is valid code in the output language,
and that the lists of required and defined objects are correct and complete.
Args:
code_string: Code to be printed out.
defined_params: Variables that will be newly defined and visible to code in sequence after these statements.
required_params: Variables that are required as input to these statements.
"""
def __init__(
self,
code_string: str,
defines: Iterable[SfgVar],
depends: Iterable[SfgVar],
includes: Iterable[HeaderFile] = (),
):
super().__init__()
self._code_string = code_string
self._defines = set(defines)
self._depends = set(depends)
self._includes = set(includes)
@property
def depends(self) -> set[SfgVar]:
return self._depends
@property
def defines(self) -> set[SfgVar]:
return self._defines
@property
def code_string(self) -> str:
return self._code_string
def get_code(self, cstyle: CodeStyle) -> str:
return self._code_string
class SfgFunctionParams(SfgEmptyNode):
def __init__(self, parameters: Sequence[SfgVar]):
super().__init__()
self._params = set(parameters)
@property
def depends(self) -> set[SfgVar]:
return self._params
class SfgRequireIncludes(SfgEmptyNode):
def __init__(self, includes: Iterable[HeaderFile]):
super().__init__()
self._includes = set(includes)
@property
def depends(self) -> set[SfgVar]:
return set()
class SfgSequence(SfgCallTreeNode):
__match_args__ = ("children",)
def __init__(self, children: Sequence[SfgCallTreeNode]):
super().__init__()
self._children = list(children)
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return self._children
@children.setter
def children(self, cs: Sequence[SfgCallTreeNode]):
self._children = list(cs)
def __getitem__(self, idx: int) -> SfgCallTreeNode:
return self._children[idx]
def __setitem__(self, idx: int, c: SfgCallTreeNode):
self._children[idx] = c
def get_code(self, cstyle: CodeStyle) -> str:
return "\n".join(c.get_code(cstyle) for c in self._children)
class SfgBlock(SfgCallTreeNode):
def __init__(self, seq: SfgSequence):
super().__init__()
self._seq = seq
@property
def sequence(self) -> SfgSequence:
return self._seq
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return (self._seq,)
def get_code(self, cstyle: CodeStyle) -> str:
seq_code = cstyle.indent(self._seq.get_code(cstyle))
return "{\n" + seq_code + "\n}"
# class SfgForLoop(SfgCallTreeNode):
# def __init__(self, control_line: SfgStatements, body: SfgCallTreeNode):
# super().__init__(control_line, body)
# @property
# def body(self) -> SfgStatements:
# return cast(SfgStatements)
class SfgKernelCallNode(SfgCallTreeLeaf):
def __init__(self, kernel_handle: SfgKernelHandle):
super().__init__()
self._kernel_handle = kernel_handle
@property
def depends(self) -> set[SfgVar]:
return set(self._kernel_handle.parameters)
def get_code(self, cstyle: CodeStyle) -> str:
kparams = self._kernel_handle.parameters
fnc_name = self._kernel_handle.fqname
call_parameters = ", ".join([p.name for p in kparams])
return f"{fnc_name}({call_parameters});"
class SfgGpuKernelInvocation(SfgCallTreeNode):
"""A CUDA or HIP kernel invocation.
See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#execution-configuration
or https://rocmdocs.amd.com/projects/HIP/en/latest/how-to/hip_cpp_language_extensions.html#calling-global-functions
for the syntax.
"""
def __init__(
self,
kernel_handle: SfgKernelHandle,
grid_size: SfgStatements,
block_size: SfgStatements,
shared_memory_bytes: SfgStatements | None,
stream: SfgStatements | None,
):
from pystencils.codegen import GpuKernel
kernel = kernel_handle.kernel
if not isinstance(kernel, GpuKernel):
raise ValueError(
"An `SfgGpuKernelInvocation` node can only call GPU kernels."
)
super().__init__()
self._kernel_handle = kernel_handle
self._grid_size = grid_size
self._block_size = block_size
self._shared_memory_bytes = shared_memory_bytes
self._stream = stream
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return (
(
self._grid_size,
self._block_size,
)
+ (
(self._shared_memory_bytes,)
if self._shared_memory_bytes is not None
else ()
)
+ ((self._stream,) if self._stream is not None else ())
)
@property
def depends(self) -> set[SfgVar]:
return set(self._kernel_handle.parameters)
def get_code(self, cstyle: CodeStyle) -> str:
kparams = self._kernel_handle.parameters
fnc_name = self._kernel_handle.fqname
call_parameters = ", ".join([p.name for p in kparams])
grid_args = [self._grid_size, self._block_size]
if self._shared_memory_bytes is not None:
grid_args += [self._shared_memory_bytes]
if self._stream is not None:
grid_args += [self._stream]
grid = "<<< " + ", ".join(arg.get_code(cstyle) for arg in grid_args) + " >>>"
return f"{fnc_name}{grid}({call_parameters});"
class SfgBranch(SfgCallTreeNode):
def __init__(
self,
cond: SfgStatements,
branch_true: SfgSequence,
branch_false: SfgSequence | None = None,
):
super().__init__()
self._cond = cond
self._branch_true = branch_true
self._branch_false = branch_false
@property
def condition(self) -> SfgStatements:
return self._cond
@property
def branch_true(self) -> SfgSequence:
return self._branch_true
@property
def branch_false(self) -> SfgSequence | None:
return self._branch_false
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return (
self._cond,
self._branch_true,
) + ((self.branch_false,) if self.branch_false is not None else ())
def get_code(self, cstyle: CodeStyle) -> str:
code = f"if({self.condition.get_code(cstyle)}) {{\n"
code += cstyle.indent(self.branch_true.get_code(cstyle))
code += "\n}"
if self.branch_false is not None:
code += "else {\n"
code += cstyle.indent(self.branch_false.get_code(cstyle))
code += "\n}"
return code
class SfgSwitchCase(SfgCallTreeNode):
DefaultCaseType = NewType("DefaultCaseType", object)
"""Sentinel type representing the ``default`` case."""
Default = DefaultCaseType(object())
def __init__(self, label: str | SfgSwitchCase.DefaultCaseType, body: SfgSequence):
super().__init__()
self._label = label
self._body = body
@property
def label(self) -> str | DefaultCaseType:
return self._label
@property
def body(self) -> SfgSequence:
return self._body
@property
def children(self) -> Sequence[SfgCallTreeNode]:
return (self._body,)
@property
def is_default(self) -> bool:
return self._label == SfgSwitchCase.Default
def get_code(self, cstyle: CodeStyle) -> str:
code = ""
if self._label == SfgSwitchCase.Default:
code += "default: {\n"
else:
code += f"case {self._label}: {{\n"
code += cstyle.indent(self.body.get_code(cstyle))
code += "\n}"
return code
class SfgSwitch(SfgCallTreeNode):
def __init__(
self,
switch_arg: SfgStatements,
cases_dict: dict[str, SfgSequence],
default: SfgSequence | None = None,
):
super().__init__()
self._cases = [SfgSwitchCase(label, body) for label, body in cases_dict.items()]
if default is not None:
# invariant: the default case is always the last child
self._cases += [SfgSwitchCase(SfgSwitchCase.Default, default)]
self._switch_arg = switch_arg
self._default = (
SfgSwitchCase(SfgSwitchCase.Default, default)
if default is not None
else None
)
@property
def switch_arg(self) -> str | SfgStatements:
return self._switch_arg
@property
def default(self) -> SfgCallTreeNode | None:
return self._default
@property
def children(self) -> tuple[SfgCallTreeNode, ...]:
return (self._switch_arg,) + tuple(self._cases)
@property
def cases(self) -> tuple[SfgCallTreeNode, ...]:
if self._default is not None:
return tuple(self._cases[:-1])
else:
return tuple(self._cases)
@cases.setter
def cases(self, cs: Sequence[SfgSwitchCase]) -> None:
if len(cs) != len(self._cases):
raise ValueError("The number of child nodes must remain the same!")
self._default = None
for i, c in enumerate(cs):
if c.is_default:
if i != len(cs) - 1:
raise ValueError("Default case must be listed last.")
else:
self._default = c
self._children = list(cs)
def set_case(self, idx: int, c: SfgSwitchCase):
if c.is_default:
if idx != len(self._children) - 1:
raise ValueError("Default case must be the last child.")
elif self._default is None:
raise ValueError("Cannot replace normal case with default case.")
else:
self._default = c
self._children[-1] = c
else:
self._children[idx] = c
def get_code(self, cstyle: CodeStyle) -> str:
code = f"switch({self._switch_arg.get_code(cstyle)}) {{\n"
code += "\n".join(c.get_code(cstyle) for c in self._cases)
code += "}"
return code