Coverage for src/pystencilssfg/ir/postprocessing.py: 94%
159 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from __future__ import annotations
2from typing import Sequence, Iterable
3import warnings
4from dataclasses import dataclass
6from abc import ABC, abstractmethod
8import sympy as sp
10from pystencils import Field
11from pystencils.types import deconstify, PsType
12from pystencils.codegen.properties import FieldBasePtr, FieldShape, FieldStride
14from ..exceptions import SfgException
15from ..config import CodeStyle
17from .call_tree import SfgCallTreeNode, SfgSequence, SfgStatements
18from ..lang.expressions import SfgKernelParamVar
19from ..lang import (
20 SfgVar,
21 SupportsFieldExtraction,
22 SupportsVectorExtraction,
23 ExprLike,
24 AugExpr,
25 depends,
26 includes,
27)
class PostProcessingContext:
    """Bookkeeping of live variables during a backward walk over a call tree.

    Variables are tracked by name. Using a variable (``_use``) makes it live;
    defining it (``_define``) removes it from the live set again.
    """

    def __init__(self) -> None:
        # name -> the SfgVar currently considered live under that name
        self._live_variables: dict[str, SfgVar] = {}

    @property
    def live_variables(self) -> set[SfgVar]:
        """A fresh set containing every currently live variable."""
        return {var for var in self._live_variables.values()}

    def get_live_variable(self, name: str) -> SfgVar | None:
        """Look up the live variable called ``name``; ``None`` if it is not live."""
        try:
            return self._live_variables[name]
        except KeyError:
            return None

    def _define(self, vars: Iterable[SfgVar], expr: str):
        """Record that ``vars`` are defined by the code ``expr``.

        Every defined variable that is currently live stops being live. Before it
        is dropped, its definition type is compared with the live type, and a
        ``UserWarning`` is issued if they conflict.
        """
        for defined in vars:
            name = defined.name
            if name not in self._live_variables:
                continue

            live_type = self._live_variables[name].dtype
            def_type = defined.dtype

            # A const definition cannot satisfy a non-const live variable;
            # apart from constness, both types have to coincide.
            const_mismatch = def_type.const and not live_type.const
            if const_mismatch or deconstify(def_type) != deconstify(live_type):
                warnings.warn(
                    f"Type conflict at variable definition: Expected type {live_type}, but got {def_type}.\n"
                    f" * At definition {expr}",
                    UserWarning,
                )

            del self._live_variables[name]

    def _use(self, vars: Iterable[SfgVar]):
        """Record that ``vars`` are read, making them live.

        When a variable of the same name is already live, the two are reconciled:
        equal variables are left alone; distinct variables with equal types
        trigger a warning; types differing only in constness keep the non-const
        variable; any other type difference raises an ``SfgException``.
        """
        for used in vars:
            name = used.name
            live = self._live_variables.get(name)

            if live is None:
                self._live_variables[name] = used
            elif used != live:
                if used.dtype == live.dtype:
                    # Only expected for SymbolLike variables wrapping a
                    # field-associated kernel parameter.
                    # TODO: Once symbol properties are a thing, check and combine them here
                    warnings.warn(
                        "Encountered two non-identical variables with same name and data type:\n"
                        f" {used.name_and_type()}\n"
                        "and\n"
                        f" {live.name_and_type()}\n"
                    )
                elif deconstify(used.dtype) == deconstify(live.dtype):
                    # Same base type, different constness: prefer the non-const one
                    if live.dtype.const and not used.dtype.const:
                        self._live_variables[name] = used
                else:
                    raise SfgException(
                        "Encountered two variables with same name but different data types:\n"
                        f" {used.name_and_type()}\n"
                        "and\n"
                        f" {live.name_and_type()}"
                    )
@dataclass(frozen=True)
class PostProcessingResult:
    """Immutable outcome of post-processing a call tree."""

    # Variables that are still live at the root of the processed call tree.
    function_params: set[SfgVar]
class CallTreePostProcessing:
    """Computes the live variables of a call tree, expanding deferred nodes in place."""

    def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult:
        """Analyse ``ast`` and wrap its live variables in a ``PostProcessingResult``."""
        return PostProcessingResult(self.get_live_variables(ast))

    def handle_sequence(self, seq: SfgSequence, ppc: PostProcessingContext):
        """Walk ``seq`` and its nested sequences from the last child to the first.

        Deferred children are expanded against the variables live at that point
        and written back into their parent sequence. For every other non-sequence
        child, statement definitions are recorded first, then its live variables
        are marked as used in ``ppc``.
        """

        def visit_backwards(sequence: SfgSequence):
            for idx in reversed(range(len(sequence.children))):
                child = sequence.children[idx]

                if isinstance(child, SfgDeferredNode):
                    child = child.expand(ppc)
                    sequence[idx] = child

                if isinstance(child, SfgSequence):
                    # nested sequences share the same context
                    visit_backwards(child)
                    continue

                if isinstance(child, SfgStatements):
                    ppc._define(child.defines, child.code_string)

                ppc._use(self.get_live_variables(child))

        visit_backwards(seq)

    def get_live_variables(self, node: SfgCallTreeNode) -> set[SfgVar]:
        """Return the variables ``node`` needs from its surrounding scope.

        Raises:
            SfgException: if ``node`` is a deferred node outside of a sequence.
        """
        if isinstance(node, SfgSequence):
            ctx = PostProcessingContext()
            self.handle_sequence(node, ctx)
            return ctx.live_variables

        if isinstance(node, SfgDeferredNode):
            raise SfgException("Deferred nodes can only occur inside a sequence.")

        child_vars = (self.get_live_variables(child) for child in node.children)
        return node.depends.union(*child_vars)
class SfgDeferredNode(SfgCallTreeNode, ABC):
    """Placeholder node in a kernel call tree that has to be expanded later.

    Subclasses stand in for nodes that cannot be built yet because the
    information they need is not known at insertion time. Until ``expand``
    has replaced them, they expose no children and render no code.
    """

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        """Always raises: a deferred node must be expanded before traversal."""
        raise SfgException(
            "Invalid access into deferred node; deferred nodes must be expanded first."
        )

    @abstractmethod
    def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
        """Build the concrete replacement node using the live variables in ``ppc``."""
        ...

    def get_code(self, cstyle: CodeStyle) -> str:
        """Always raises: a deferred node has no code until it is expanded."""
        raise SfgException(
            "Invalid access into deferred node; deferred nodes must be expanded first."
        )
class SfgDeferredParamSetter(SfgDeferredNode):
    """Deferred definition of a parameter from ``rhs``, emitted only if the parameter is live."""

    def __init__(self, param: SfgVar | sp.Symbol, rhs: ExprLike):
        self._lhs = param
        self._rhs = rhs

    def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
        """Define the live variable named like ``param`` as ``rhs``.

        If no variable of that name is live, an empty sequence is returned.
        """
        target = ppc.get_live_variable(self._lhs.name)
        if target is None:
            # Not live at this point -> no definition needed
            return SfgSequence([])

        rhs = self._rhs
        definition = f"{target.dtype.c_string()} {target.name} = {rhs};"
        return SfgStatements(definition, (target,), depends(rhs), includes(rhs))
179class SfgDeferredFieldMapping(SfgDeferredNode):
180 """Deferred mapping of a pystencils field to a field data structure."""
182 # NOTE ON Scalar Fields
183 #
184 # pystencils permits explicit (`index_shape = (1,)`) and implicit (`index_shape = ()`)
185 # scalar fields. In order to handle both equivalently,
186 # we ignore the trivial explicit scalar dimension in field extraction.
187 # This makes sure that explicit D-dimensional scalar fields
188 # can be mapped onto D-dimensional data structures, and do not require that
189 # D+1st dimension.
191 def __init__(
192 self,
193 psfield: Field,
194 extraction: SupportsFieldExtraction,
195 cast_indexing_symbols: bool = True,
196 ):
197 self._field = psfield
198 self._extraction = extraction
199 self._cast_indexing_symbols = cast_indexing_symbols
201 def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
202 # Find field pointer
203 ptr: SfgKernelParamVar | None = None
204 rank: int
206 if self._field.index_shape == (1,):
207 # explicit scalar field -> ignore index dimensions
208 rank = self._field.spatial_dimensions
209 else:
210 rank = len(self._field.shape)
212 shape: list[SfgKernelParamVar | str | None] = [None] * rank
213 strides: list[SfgKernelParamVar | str | None] = [None] * rank
215 for param in ppc.live_variables:
216 if isinstance(param, SfgKernelParamVar):
217 for prop in param.wrapped.properties:
218 match prop:
219 case FieldBasePtr(field) if field == self._field:
220 ptr = param
221 case FieldShape(field, coord) if field == self._field: # type: ignore
222 shape[coord] = param # type: ignore
223 case FieldStride(field, coord) if field == self._field: # type: ignore
224 strides[coord] = param # type: ignore
226 # Find constant or otherwise determined sizes
227 for coord, s in enumerate(self._field.shape[:rank]):
228 if shape[coord] is None:
229 shape[coord] = str(s)
231 # Find constant or otherwise determined strides
232 for coord, s in enumerate(self._field.strides[:rank]):
233 if strides[coord] is None:
234 strides[coord] = str(s)
236 # Now we have all the symbols, start extracting them
237 nodes = []
238 done: set[SfgKernelParamVar] = set()
240 if ptr is not None:
241 expr = self._extraction._extract_ptr()
242 nodes.append(
243 SfgStatements(
244 f"{ptr.dtype.c_string()} {ptr.name} { {expr} } ;",
245 (ptr,),
246 depends(expr),
247 includes(expr),
248 )
249 )
251 def maybe_cast(expr: AugExpr, target_type: PsType) -> AugExpr:
252 if self._cast_indexing_symbols:
253 return AugExpr(target_type).bind(
254 "{}( {} )", deconstify(target_type).c_string(), expr
255 )
256 else:
257 return expr
259 def get_shape(coord, symb: SfgKernelParamVar | str):
260 expr = self._extraction._extract_size(coord)
262 if expr is None:
263 raise SfgException(
264 f"Cannot extract shape in coordinate {coord} from {self._extraction}"
265 )
267 if isinstance(symb, SfgKernelParamVar) and symb not in done:
268 done.add(symb)
269 expr = maybe_cast(expr, symb.dtype)
270 return SfgStatements(
271 f"{symb.dtype.c_string()} {symb.name} { {expr} } ;",
272 (symb,),
273 depends(expr),
274 includes(expr),
275 )
276 else:
277 return SfgStatements(f"/* {expr} == {symb} */", (), ())
279 def get_stride(coord, symb: SfgKernelParamVar | str):
280 expr = self._extraction._extract_stride(coord)
282 if expr is None:
283 raise SfgException(
284 f"Cannot extract stride in coordinate {coord} from {self._extraction}"
285 )
287 if isinstance(symb, SfgKernelParamVar) and symb not in done:
288 done.add(symb)
289 expr = maybe_cast(expr, symb.dtype)
290 return SfgStatements(
291 f"{symb.dtype.c_string()} {symb.name} { {expr} } ;",
292 (symb,),
293 depends(expr),
294 includes(expr),
295 )
296 else:
297 return SfgStatements(f"/* {expr} == {symb} */", (), ())
299 nodes += [get_shape(c, s) for c, s in enumerate(shape) if s is not None]
300 nodes += [get_stride(c, s) for c, s in enumerate(strides) if s is not None]
302 return SfgSequence(nodes)
class SfgDeferredVectorMapping(SfgDeferredNode):
    """Deferred mapping of scalar variables onto the components of a vector.

    ``scalars[i]`` is bound to component ``i`` of ``vector``. On expansion, only
    those scalars that are live (matched by name) receive a definition.
    """

    def __init__(
        self,
        scalars: Sequence[sp.Symbol | SfgVar],
        vector: SupportsVectorExtraction,
    ):
        # name -> (component index, original scalar)
        self._scalars = {sc.name: (i, sc) for i, sc in enumerate(scalars)}
        self._vector = vector

    def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
        """Define each live scalar from its vector component.

        Definitions are emitted in ascending component order, so the generated
        code does not depend on the iteration order of the live-variable set.
        """
        # Live variable names are unique, so component indices are unique too;
        # sort on the index only to avoid comparing variables.
        matches = sorted(
            (
                (self._scalars[param.name][0], param)
                for param in ppc.live_variables
                if param.name in self._scalars
            ),
            key=lambda pair: pair[0],
        )

        nodes = []
        for idx, param in matches:
            expr = self._vector._extract_component(idx)
            nodes.append(
                SfgStatements(
                    # `{{ ... }}` yields literal braces: C++ brace-initialization
                    f"{param.dtype.c_string()} {param.name} {{ {expr} }};",
                    (param,),
                    depends(expr),
                    includes(expr),
                )
            )

        return SfgSequence(nodes)