from __future__ import annotations from typing import Sequence, Iterable import warnings from dataclasses import dataclass from abc import ABC, abstractmethod import sympy as sp from pystencils import Field from pystencils.types import deconstify, PsType from pystencils.codegen.properties import FieldBasePtr, FieldShape, FieldStride from ..exceptions import SfgException from ..config import CodeStyle from .call_tree import SfgCallTreeNode, SfgSequence, SfgStatements from ..lang.expressions import SfgKernelParamVar from ..lang import ( SfgVar, SupportsFieldExtraction, SupportsVectorExtraction, ExprLike, AugExpr, depends, includes, ) class FlattenSequences: """Flattens any nested sequences occuring in a kernel call tree.""" def __call__(self, node: SfgCallTreeNode) -> None: self.visit(node) def visit(self, node: SfgCallTreeNode): match node: case SfgSequence(): self.flatten(node) case _: for c in node.children: self.visit(c) def flatten(self, sequence: SfgSequence) -> None: children_flattened: list[SfgCallTreeNode] = [] def flatten(seq: SfgSequence): for c in seq.children: if isinstance(c, SfgSequence): flatten(c) else: children_flattened.append(c) flatten(sequence) for c in children_flattened: self.visit(c) sequence.children = children_flattened class PostProcessingContext: def __init__(self) -> None: self._live_variables: dict[str, SfgVar] = dict() @property def live_variables(self) -> set[SfgVar]: return set(self._live_variables.values()) def get_live_variable(self, name: str) -> SfgVar | None: return self._live_variables.get(name) def _define(self, vars: Iterable[SfgVar], expr: str): for var in vars: if var.name in self._live_variables: live_var = self._live_variables[var.name] live_var_dtype = live_var.dtype def_dtype = var.dtype # A const definition conflicts with a non-const live variable # A non-const definition is always OK, but then the types must be the same if (def_dtype.const and not live_var_dtype.const) or ( deconstify(def_dtype) != deconstify(live_var_dtype) ): warnings.warn( f"Type conflict at variable definition: Expected type {live_var_dtype}, but got {def_dtype}.\n" f" * At definition {expr}", UserWarning, ) del self._live_variables[var.name] def _use(self, vars: Iterable[SfgVar]): for var in vars: if var.name in self._live_variables: live_var = self._live_variables[var.name] if var != live_var: if var.dtype == live_var.dtype: # This can only happen if the variables are SymbolLike, # i.e. wrap a field-associated kernel parameter # TODO: Once symbol properties are a thing, check and combine them here warnings.warn( "Encountered two non-identical variables with same name and data type:\n" f" {var.name_and_type()}\n" "and\n" f" {live_var.name_and_type()}\n" ) elif deconstify(var.dtype) == deconstify(live_var.dtype): # Same type, just different constness # One of them must be non-const -> keep the non-const one if live_var.dtype.const and not var.dtype.const: self._live_variables[var.name] = var else: raise SfgException( "Encountered two variables with same name but different data types:\n" f" {var.name_and_type()}\n" "and\n" f" {live_var.name_and_type()}" ) else: self._live_variables[var.name] = var @dataclass(frozen=True) class PostProcessingResult: function_params: set[SfgVar] class CallTreePostProcessing: def __init__(self): self._flattener = FlattenSequences() def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult: live_vars = self.get_live_variables(ast) return PostProcessingResult(live_vars) def handle_sequence(self, seq: SfgSequence, ppc: PostProcessingContext): def iter_nested_sequences(seq: SfgSequence): for i in range(len(seq.children) - 1, -1, -1): c = seq.children[i] if isinstance(c, SfgDeferredNode): c = c.expand(ppc) seq[i] = c if isinstance(c, SfgSequence): iter_nested_sequences(c) else: if isinstance(c, SfgStatements): ppc._define(c.defines, c.code_string) ppc._use(self.get_live_variables(c)) iter_nested_sequences(seq) def get_live_variables(self, node: SfgCallTreeNode) -> set[SfgVar]: match node: case SfgSequence(): ppc = PostProcessingContext() self.handle_sequence(node, ppc) return ppc.live_variables case SfgDeferredNode(): raise SfgException("Deferred nodes can only occur inside a sequence.") case _: return node.depends.union( *(self.get_live_variables(c) for c in node.children) ) class SfgDeferredNode(SfgCallTreeNode, ABC): """Nodes of this type are inserted as placeholders into the kernel call tree and need to be expanded at a later time. Subclasses of SfgDeferredNode correspond to nodes that cannot be created yet because information required for their construction is not yet known. """ @property def children(self) -> Sequence[SfgCallTreeNode]: raise SfgException( "Invalid access into deferred node; deferred nodes must be expanded first." ) @abstractmethod def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: pass def get_code(self, cstyle: CodeStyle) -> str: raise SfgException( "Invalid access into deferred node; deferred nodes must be expanded first." ) class SfgDeferredParamSetter(SfgDeferredNode): def __init__(self, param: SfgVar | sp.Symbol, rhs: ExprLike): self._lhs = param self._rhs = rhs def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: live_var = ppc.get_live_variable(self._lhs.name) if live_var is not None: code = f"{live_var.dtype.c_string()} {live_var.name} = {self._rhs};" return SfgStatements( code, (live_var,), depends(self._rhs), includes(self._rhs) ) else: return SfgSequence([]) class SfgDeferredFieldMapping(SfgDeferredNode): """Deferred mapping of a pystencils field to a field data structure.""" def __init__( self, psfield: Field, extraction: SupportsFieldExtraction, cast_indexing_symbols: bool = True, ): self._field = psfield self._extraction = extraction self._cast_indexing_symbols = cast_indexing_symbols def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: # Find field pointer ptr: SfgKernelParamVar | None = None shape: list[SfgKernelParamVar | str | None] = [None] * len(self._field.shape) strides: list[SfgKernelParamVar | str | None] = [None] * len( self._field.strides ) index_dims_are_trivial = self._field.index_shape == (1,) rank = len(self._field.shape) for param in ppc.live_variables: if isinstance(param, SfgKernelParamVar): for prop in param.wrapped.properties: match prop: case FieldBasePtr(field) if field == self._field: ptr = param case FieldShape(field, coord) if field == self._field: # type: ignore shape[coord] = param # type: ignore case FieldStride(field, coord) if field == self._field: # type: ignore strides[coord] = param # type: ignore # Find constant or otherwise determined sizes for coord, s in enumerate(self._field.shape): if shape[coord] is None: shape[coord] = str(s) # Find constant or otherwise determined strides for coord, s in enumerate(self._field.strides): if strides[coord] is None: strides[coord] = str(s) # Now we have all the symbols, start extracting them nodes = [] done: set[SfgKernelParamVar] = set() if ptr is not None: expr = self._extraction._extract_ptr() nodes.append( SfgStatements( f"{ptr.dtype.c_string()} {ptr.name} {{ {expr} }};", (ptr,), depends(expr), includes(expr), ) ) def maybe_cast(expr: AugExpr, target_type: PsType) -> AugExpr: if self._cast_indexing_symbols: return AugExpr(target_type).bind( "{}( {} )", deconstify(target_type).c_string(), expr ) else: return expr def get_shape(coord, symb: SfgKernelParamVar | str): expr = self._extraction._extract_size(coord) if expr is None: if index_dims_are_trivial and coord == rank - 1: expr = AugExpr.format("1") else: raise SfgException( f"Cannot extract shape in coordinate {coord} from {self._extraction}" ) if isinstance(symb, SfgKernelParamVar) and symb not in done: done.add(symb) expr = maybe_cast(expr, symb.dtype) return SfgStatements( f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), depends(expr), includes(expr), ) else: return SfgStatements(f"/* {expr} == {symb} */", (), ()) def get_stride(coord, symb: SfgKernelParamVar | str): expr = self._extraction._extract_stride(coord) if expr is None: if index_dims_are_trivial and coord == rank - 1: expr = AugExpr.format("1") else: raise SfgException( f"Cannot extract stride in coordinate {coord} from {self._extraction}" ) if isinstance(symb, SfgKernelParamVar) and symb not in done: done.add(symb) expr = maybe_cast(expr, symb.dtype) return SfgStatements( f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), depends(expr), includes(expr), ) else: return SfgStatements(f"/* {expr} == {symb} */", (), ()) nodes += [get_shape(c, s) for c, s in enumerate(shape) if s is not None] nodes += [get_stride(c, s) for c, s in enumerate(strides) if s is not None] return SfgSequence(nodes) class SfgDeferredVectorMapping(SfgDeferredNode): def __init__( self, scalars: Sequence[sp.Symbol | SfgVar], vector: SupportsVectorExtraction, ): self._scalars = {sc.name: (i, sc) for i, sc in enumerate(scalars)} self._vector = vector def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: nodes = [] for param in ppc.live_variables: if param.name in self._scalars: idx, _ = self._scalars[param.name] expr = self._vector._extract_component(idx) nodes.append( SfgStatements( f"{param.dtype.c_string()} {param.name} {{ {expr} }};", (param,), depends(expr), includes(expr), ) ) return SfgSequence(nodes)