from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, Iterable, NewType
from abc import ABC, abstractmethod

from .entities import SfgKernelHandle
from ..lang import SfgVar, HeaderFile

if TYPE_CHECKING:
    from ..config import CodeStyle


class SfgCallTreeNode(ABC):
    """Base class for all nodes comprising SFG call trees.

    ## Code Printing

    For extensibility, code printing is implemented inside the call tree.
    Therefore, every instantiable call tree node must implement the method `get_code`.
    By convention, the string returned by `get_code` should not contain a trailing newline.
    """

    def __init__(self) -> None:
        # Header files this node needs; subclasses may overwrite this set.
        self._includes: set[HeaderFile] = set()

    @property
    @abstractmethod
    def children(self) -> Sequence[SfgCallTreeNode]:
        """This node's children"""

    @abstractmethod
    def get_code(self, cstyle: CodeStyle) -> str:
        """Returns the code of this node.

        By convention, the code block emitted by this function should not contain a trailing newline.
        """

    @property
    def depends(self) -> set[SfgVar]:
        """Set of objects this leaf depends on"""
        return set()

    @property
    def required_includes(self) -> set[HeaderFile]:
        """Return a set of header includes required by this node"""
        return self._includes


class SfgCallTreeLeaf(SfgCallTreeNode, ABC):
    """A leaf node of the call tree.

    Leaf nodes must implement ``depends`` for automatic parameter collection.
    """

    def __init__(self):
        super().__init__()

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        """Leaves have no children; always returns an empty tuple."""
        return ()


class SfgEmptyNode(SfgCallTreeLeaf):
    """A leaf node that does not emit any code.

    Empty nodes must still implement ``depends``.
    """

    def __init__(self):
        super().__init__()

    def get_code(self, cstyle: CodeStyle) -> str:
        """Return the empty string; this node prints nothing."""
        return ""


class SfgStatements(SfgCallTreeLeaf):
    """Represents (a sequence of) statements in the source language.

    This class groups together arbitrary code strings
    (e.g. sequences of C++ statements, cf. https://en.cppreference.com/w/cpp/language/statements),
    and annotates them with the set of symbols read and written by these statements.

    It is the user's responsibility to ensure that the code string is valid code in the output language,
    and that the lists of required and defined objects are correct and complete.

    Args:
        code_string: Code to be printed out.
        defines: Variables that will be newly defined and visible to code in sequence after these statements.
        depends: Variables that are required as input to these statements.
        includes: Header files required by these statements.
    """

    def __init__(
        self,
        code_string: str,
        defines: Iterable[SfgVar],
        depends: Iterable[SfgVar],
        includes: Iterable[HeaderFile] = (),
    ):
        super().__init__()

        self._code_string = code_string

        self._defines = set(defines)
        self._depends = set(depends)
        # Replaces the empty include set created by the base class.
        self._includes = set(includes)

    @property
    def depends(self) -> set[SfgVar]:
        """Variables required as input to these statements."""
        return self._depends

    @property
    def defines(self) -> set[SfgVar]:
        """Variables newly defined by these statements."""
        return self._defines

    @property
    def code_string(self) -> str:
        """The raw code string held by this node."""
        return self._code_string

    def get_code(self, cstyle: CodeStyle) -> str:
        """Return the stored code string unchanged; ``cstyle`` is not consulted."""
        return self._code_string


class SfgFunctionParams(SfgEmptyNode):
    """An empty node that reports a fixed set of variables as its dependencies."""

    def __init__(self, parameters: Sequence[SfgVar]):
        super().__init__()
        self._params = set(parameters)

    @property
    def depends(self) -> set[SfgVar]:
        """The variables passed to the constructor."""
        return self._params


class SfgRequireIncludes(SfgEmptyNode):
    """An empty node that only contributes header includes and has no dependencies."""

    def __init__(self, includes: Iterable[HeaderFile]):
        super().__init__()
        self._includes = set(includes)

    @property
    def depends(self) -> set[SfgVar]:
        """Always empty."""
        return set()


class SfgSequence(SfgCallTreeNode):
    """An ordered, mutable list of child nodes whose code is printed one after another."""

    __match_args__ = ("children",)

    def __init__(self, children: Sequence[SfgCallTreeNode]):
        super().__init__()
        self._children = list(children)

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        """The child nodes, in order."""
        return self._children

    @children.setter
    def children(self, cs: Sequence[SfgCallTreeNode]):
        # Copy so that later changes to the caller's sequence do not leak in.
        self._children = list(cs)

    def __getitem__(self, idx: int) -> SfgCallTreeNode:
        return self._children[idx]

    def __setitem__(self, idx: int, c: SfgCallTreeNode):
        self._children[idx] = c

    def get_code(self, cstyle: CodeStyle) -> str:
        """Return the children's code joined by newlines."""
        return "\n".join(c.get_code(cstyle) for c in self._children)


class SfgBlock(SfgCallTreeNode):
    """A sequence wrapped in a pair of curly braces."""

    def __init__(self, seq: SfgSequence):
        super().__init__()
        self._seq = seq

    @property
    def sequence(self) -> SfgSequence:
        """The wrapped sequence."""
        return self._seq

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        """A one-element tuple holding the wrapped sequence."""
        return (self._seq,)

    def get_code(self, cstyle: CodeStyle) -> str:
        """Return the sequence's code, indented by ``cstyle`` and enclosed in braces."""
        seq_code = cstyle.indent(self._seq.get_code(cstyle))

        return "{\n" + seq_code + "\n}"


# class SfgForLoop(SfgCallTreeNode):
#     def __init__(self, control_line: SfgStatements, body: SfgCallTreeNode):
#         super().__init__(control_line, body)

#     @property
#     def body(self) -> SfgStatements:
#         return cast(SfgStatements)


class SfgKernelCallNode(SfgCallTreeLeaf):
    """A plain function-call statement invoking the kernel behind a kernel handle."""

    def __init__(self, kernel_handle: SfgKernelHandle):
        super().__init__()
        self._kernel_handle = kernel_handle

    @property
    def depends(self) -> set[SfgVar]:
        """The kernel handle's parameters."""
        return set(self._kernel_handle.parameters)

    def get_code(self, cstyle: CodeStyle) -> str:
        """Return ``<fqname>(<param names, comma-separated>);``."""
        kparams = self._kernel_handle.parameters
        fnc_name = self._kernel_handle.fqname
        call_parameters = ", ".join([p.name for p in kparams])

        return f"{fnc_name}({call_parameters});"


class SfgGpuKernelInvocation(SfgCallTreeNode):
    """A CUDA or HIP kernel invocation.

    See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#execution-configuration
    or https://rocmdocs.amd.com/projects/HIP/en/latest/how-to/hip_cpp_language_extensions.html#calling-global-functions
    for the syntax.

    Args:
        kernel_handle: Handle of the kernel to launch; its ``kernel`` must be a ``GpuKernel``.
        grid_size: Statements printed as the first launch-configuration argument.
        block_size: Statements printed as the second launch-configuration argument.
        shared_memory_bytes: Optional third launch-configuration argument.
        stream: Optional final launch-configuration argument.

    Raises:
        ValueError: If the handle's kernel is not a ``GpuKernel``.
    """

    def __init__(
        self,
        kernel_handle: SfgKernelHandle,
        grid_size: SfgStatements,
        block_size: SfgStatements,
        shared_memory_bytes: SfgStatements | None,
        stream: SfgStatements | None,
    ):
        # Imported locally, as in the original module layout.
        from pystencils.codegen import GpuKernel

        kernel = kernel_handle.kernel
        if not isinstance(kernel, GpuKernel):
            raise ValueError(
                "An `SfgGpuKernelInvocation` node can only call GPU kernels."
            )

        super().__init__()
        self._kernel_handle = kernel_handle
        self._grid_size = grid_size
        self._block_size = block_size
        self._shared_memory_bytes = shared_memory_bytes
        self._stream = stream

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        """The launch-configuration nodes that are present, in printing order."""
        return (
            (
                self._grid_size,
                self._block_size,
            )
            + (
                (self._shared_memory_bytes,)
                if self._shared_memory_bytes is not None
                else ()
            )
            + ((self._stream,) if self._stream is not None else ())
        )

    @property
    def depends(self) -> set[SfgVar]:
        """The kernel handle's parameters."""
        return set(self._kernel_handle.parameters)

    def get_code(self, cstyle: CodeStyle) -> str:
        """Return ``<fqname><<< grid, block[, smem][, stream] >>>(<params>);``."""
        kparams = self._kernel_handle.parameters
        fnc_name = self._kernel_handle.fqname
        call_parameters = ", ".join([p.name for p in kparams])

        # Optional arguments are appended only if given. Note that a stream
        # without shared_memory_bytes is printed in the third position.
        grid_args = [self._grid_size, self._block_size]
        if self._shared_memory_bytes is not None:
            grid_args += [self._shared_memory_bytes]

        if self._stream is not None:
            grid_args += [self._stream]

        grid = "<<< " + ", ".join(arg.get_code(cstyle) for arg in grid_args) + " >>>"
        return f"{fnc_name}{grid}({call_parameters});"


class SfgBranch(SfgCallTreeNode):
    """An ``if`` statement with an optional ``else`` branch."""

    def __init__(
        self,
        cond: SfgStatements,
        branch_true: SfgSequence,
        branch_false: SfgSequence | None = None,
    ):
        super().__init__()
        self._cond = cond
        self._branch_true = branch_true
        self._branch_false = branch_false

    @property
    def condition(self) -> SfgStatements:
        """The node printed inside the ``if(...)`` parentheses."""
        return self._cond

    @property
    def branch_true(self) -> SfgSequence:
        """The body of the ``if`` branch."""
        return self._branch_true

    @property
    def branch_false(self) -> SfgSequence | None:
        """The body of the ``else`` branch, or ``None`` if there is none."""
        return self._branch_false

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        """Condition, true branch, and (if present) false branch."""
        return (
            self._cond,
            self._branch_true,
        ) + ((self.branch_false,) if self.branch_false is not None else ())

    def get_code(self, cstyle: CodeStyle) -> str:
        """Return the ``if``/``else`` statement with bodies indented by ``cstyle``."""
        code = f"if({self.condition.get_code(cstyle)}) {{\n"
        code += cstyle.indent(self.branch_true.get_code(cstyle))
        code += "\n}"

        if self.branch_false is not None:
            code += "else {\n"
            code += cstyle.indent(self.branch_false.get_code(cstyle))
            code += "\n}"

        return code


class SfgSwitchCase(SfgCallTreeNode):
    """A single ``case`` (or the ``default`` case) of a switch statement.

    Args:
        label: The case label, or `SfgSwitchCase.Default` for the default case.
        body: The sequence printed inside the case's braces.
    """

    DefaultCaseType = NewType("DefaultCaseType", object)
    """Sentinel type representing the ``default`` case."""

    Default = DefaultCaseType(object())

    def __init__(self, label: str | SfgSwitchCase.DefaultCaseType, body: SfgSequence):
        super().__init__()
        self._label = label
        self._body = body

    @property
    def label(self) -> str | DefaultCaseType:
        """The case label, or the `Default` sentinel."""
        return self._label

    @property
    def body(self) -> SfgSequence:
        """The body of this case."""
        return self._body

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        """A one-element tuple holding the body."""
        return (self._body,)

    @property
    def is_default(self) -> bool:
        """Whether this is the ``default`` case."""
        # The sentinel is a unique object: compare by identity so that labels
        # with a custom ``__eq__`` cannot be mistaken for it.
        return self._label is SfgSwitchCase.Default

    def get_code(self, cstyle: CodeStyle) -> str:
        """Return ``case <label>: { ... }`` or ``default: { ... }``."""
        code = ""
        if self.is_default:
            code += "default: {\n"
        else:
            code += f"case {self._label}: {{\n"
        code += cstyle.indent(self.body.get_code(cstyle))
        code += "\n}"
        return code


class SfgSwitch(SfgCallTreeNode):
    """A ``switch`` statement.

    Invariant: if a default case exists, it is the last entry of the internal
    case list and ``default`` refers to that very same object.

    Args:
        switch_arg: The node printed inside the ``switch(...)`` parentheses.
        cases_dict: Mapping from case label to case body; printed in dict order.
        default: Optional body of the ``default`` case.
    """

    def __init__(
        self,
        switch_arg: SfgStatements,
        cases_dict: dict[str, SfgSequence],
        default: SfgSequence | None = None,
    ):
        super().__init__()
        self._switch_arg = switch_arg
        self._cases = [SfgSwitchCase(label, body) for label, body in cases_dict.items()]

        self._default: SfgSwitchCase | None = None
        if default is not None:
            # Create the default case exactly once so that `default` and the
            # last child are the same node.
            # invariant: the default case is always the last child
            self._default = SfgSwitchCase(SfgSwitchCase.Default, default)
            self._cases.append(self._default)

    @property
    def switch_arg(self) -> SfgStatements:
        """The node printed inside the ``switch(...)`` parentheses."""
        return self._switch_arg

    @property
    def default(self) -> SfgCallTreeNode | None:
        """The default case, or ``None`` if there is none."""
        return self._default

    @property
    def children(self) -> tuple[SfgCallTreeNode, ...]:
        """The switch argument followed by all cases (default last, if any)."""
        return (self._switch_arg,) + tuple(self._cases)

    @property
    def cases(self) -> tuple[SfgCallTreeNode, ...]:
        """All cases except the default case."""
        if self._default is not None:
            return tuple(self._cases[:-1])
        else:
            return tuple(self._cases)

    @cases.setter
    def cases(self, cs: Sequence[SfgSwitchCase]) -> None:
        """Replace the full list of cases, including the default case if any.

        Raises:
            ValueError: If the number of cases changes, or if a default case
                is listed anywhere but last.
        """
        if len(cs) != len(self._cases):
            raise ValueError("The number of child nodes must remain the same!")

        # Validate completely before mutating any state.
        new_default = None
        for i, c in enumerate(cs):
            if c.is_default:
                if i != len(cs) - 1:
                    raise ValueError("Default case must be listed last.")
                else:
                    new_default = c

        self._default = new_default
        self._cases = list(cs)

    def set_case(self, idx: int, c: SfgSwitchCase):
        """Replace the case at position ``idx`` of the full case list.

        Raises:
            ValueError: If ``c`` is a default case and ``idx`` is not the last
                position, or if there is no existing default case to replace.
        """
        if c.is_default:
            if idx != len(self._cases) - 1:
                raise ValueError("Default case must be the last child.")
            elif self._default is None:
                raise ValueError("Cannot replace normal case with default case.")
            else:
                self._default = c
                self._cases[-1] = c
        else:
            if self._cases[idx] is self._default:
                # The default slot is being overwritten by a regular case;
                # drop the stale reference so `cases` and `default` stay
                # consistent (the `cases` setter permits the same transition).
                self._default = None
            self._cases[idx] = c

    def get_code(self, cstyle: CodeStyle) -> str:
        """Return the ``switch`` statement with all cases printed in order."""
        code = f"switch({self._switch_arg.get_code(cstyle)}) {{\n"
        code += "\n".join(c.get_code(cstyle) for c in self._cases)
        code += "}"
        return code