Coverage for src/pystencilssfg/ir/postprocessing.py: 94%

159 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from __future__ import annotations 

2from typing import Sequence, Iterable 

3import warnings 

4from dataclasses import dataclass 

5 

6from abc import ABC, abstractmethod 

7 

8import sympy as sp 

9 

10from pystencils import Field 

11from pystencils.types import deconstify, PsType 

12from pystencils.codegen.properties import FieldBasePtr, FieldShape, FieldStride 

13 

14from ..exceptions import SfgException 

15from ..config import CodeStyle 

16 

17from .call_tree import SfgCallTreeNode, SfgSequence, SfgStatements 

18from ..lang.expressions import SfgKernelParamVar 

19from ..lang import ( 

20 SfgVar, 

21 SupportsFieldExtraction, 

22 SupportsVectorExtraction, 

23 ExprLike, 

24 AugExpr, 

25 depends, 

26 includes, 

27) 

28 

29 

class PostProcessingContext:
    """Bookkeeping of live variables during call-tree post-processing.

    Variables are tracked by name. ``_use`` marks variables as live, and
    ``_define`` retires them again once their definition is encountered.
    """

    def __init__(self) -> None:
        # Currently live variables, keyed by variable name.
        self._live_variables: dict[str, SfgVar] = dict()

    @property
    def live_variables(self) -> set[SfgVar]:
        """A fresh set containing every currently live variable."""
        return set(self._live_variables.values())

    def get_live_variable(self, name: str) -> SfgVar | None:
        """Return the live variable named ``name``, or ``None`` if none is live."""
        return self._live_variables.get(name)

    def _define(self, vars: Iterable[SfgVar], expr: str):
        """Retire each variable of ``vars`` that is currently live.

        Variables that are not live are ignored. If a definition's type does not
        fit the live variable's type, a ``UserWarning`` quoting the defining
        code ``expr`` is issued before the variable is retired.
        """
        for defined in vars:
            live = self._live_variables.get(defined.name)
            if live is None:
                continue

            expected_type = live.dtype
            actual_type = defined.dtype

            # A const definition of a non-const live variable is a conflict.
            # Otherwise, the two types may only differ in their constness.
            const_mismatch = actual_type.const and not expected_type.const
            if const_mismatch or (
                deconstify(actual_type) != deconstify(expected_type)
            ):
                warnings.warn(
                    f"Type conflict at variable definition: Expected type {expected_type}, but got {actual_type}.\n"
                    f" * At definition {expr}",
                    UserWarning,
                )

            del self._live_variables[defined.name]

    def _use(self, vars: Iterable[SfgVar]):
        """Mark each variable of ``vars`` as live.

        If a different variable with the same name is already live:
          - identical data types: a warning is issued and the live one is kept;
          - types differing only in constness: the non-const one is kept;
          - otherwise: an ``SfgException`` is raised.
        """
        for used in vars:
            live = self._live_variables.get(used.name)
            if live is None:
                self._live_variables[used.name] = used
                continue

            if used != live:
                if used.dtype == live.dtype:
                    # This can only happen if the variables are SymbolLike,
                    # i.e. wrap a field-associated kernel parameter
                    # TODO: Once symbol properties are a thing, check and combine them here
                    warnings.warn(
                        "Encountered two non-identical variables with same name and data type:\n"
                        f" {used.name_and_type()}\n"
                        "and\n"
                        f" {live.name_and_type()}\n"
                    )
                elif deconstify(used.dtype) == deconstify(live.dtype):
                    # Only constness differs -> prefer the non-const variable
                    if live.dtype.const and not used.dtype.const:
                        self._live_variables[used.name] = used
                else:
                    raise SfgException(
                        "Encountered two variables with same name but different data types:\n"
                        f" {used.name_and_type()}\n"
                        "and\n"
                        f" {live.name_and_type()}"
                    )

92 

93 

@dataclass(frozen=True)
class PostProcessingResult:
    """Immutable outcome of post-processing a call tree.

    Attributes:
        function_params: The variables still live at the root of the
            processed call tree.
    """

    function_params: set[SfgVar]

97 

98 

class CallTreePostProcessing:
    """Post-processing pass over a kernel call tree.

    Expands deferred nodes found inside sequences and computes the variables
    that remain live at the root of the tree.
    """

    def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult:
        return PostProcessingResult(self.get_live_variables(ast))

    def handle_sequence(self, seq: SfgSequence, ppc: PostProcessingContext):
        """Walk ``seq`` from its last child to its first, updating ``ppc``.

        Deferred children are expanded against the current state of ``ppc``
        and replaced in place. Nested sequences are walked with the same
        ``ppc``. For every other child, the variables it defines (if it is an
        ``SfgStatements``) are retired, then the variables it needs are marked live.
        """

        def visit(sequence: SfgSequence):
            for idx in reversed(range(len(sequence.children))):
                child = sequence.children[idx]

                if isinstance(child, SfgDeferredNode):
                    # Expansion sees the variables used by all later children.
                    child = child.expand(ppc)
                    sequence[idx] = child

                if isinstance(child, SfgSequence):
                    visit(child)
                    continue

                if isinstance(child, SfgStatements):
                    ppc._define(child.defines, child.code_string)

                ppc._use(self.get_live_variables(child))

        visit(seq)

    def get_live_variables(self, node: SfgCallTreeNode) -> set[SfgVar]:
        """Return the variables that are live at the entry of ``node``.

        Raises:
            SfgException: if ``node`` is a deferred node outside of any sequence.
        """
        if isinstance(node, SfgSequence):
            ppc = PostProcessingContext()
            self.handle_sequence(node, ppc)
            return ppc.live_variables

        if isinstance(node, SfgDeferredNode):
            raise SfgException("Deferred nodes can only occur inside a sequence.")

        from_children = (self.get_live_variables(c) for c in node.children)
        return node.depends.union(*from_children)

137 

138 

class SfgDeferredNode(SfgCallTreeNode, ABC):
    """Placeholder node in a kernel call tree that must be expanded later.

    Subclasses stand in for nodes that cannot be built yet, because some of the
    information needed to construct them is not known at insertion time.
    Until ``expand`` replaces the placeholder, accessing its children or its
    code raises an ``SfgException``.
    """

    # Error message for any access to an unexpanded deferred node.
    _NOT_EXPANDED_MSG = (
        "Invalid access into deferred node; deferred nodes must be expanded first."
    )

    @property
    def children(self) -> Sequence[SfgCallTreeNode]:
        raise SfgException(self._NOT_EXPANDED_MSG)

    @abstractmethod
    def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
        """Produce the concrete node, given the live-variable state in ``ppc``."""

    def get_code(self, cstyle: CodeStyle) -> str:
        raise SfgException(self._NOT_EXPANDED_MSG)

161 

162 

class SfgDeferredParamSetter(SfgDeferredNode):
    """Deferred definition of a parameter from a given right-hand side.

    If a variable with the parameter's name is live at expansion time, the
    node expands to ``<type> <name> = <rhs>;``, using the live variable's type.
    Otherwise it expands to an empty sequence.
    """

    def __init__(self, param: SfgVar | sp.Symbol, rhs: ExprLike):
        self._lhs = param
        self._rhs = rhs

    def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
        target = ppc.get_live_variable(self._lhs.name)
        if target is None:
            # The parameter is never used downstream; nothing to define.
            return SfgSequence([])

        rhs = self._rhs
        definition = f"{target.dtype.c_string()} {target.name} = {rhs};"
        return SfgStatements(definition, (target,), depends(rhs), includes(rhs))

177 

178 

179class SfgDeferredFieldMapping(SfgDeferredNode): 

180 """Deferred mapping of a pystencils field to a field data structure.""" 

181 

182 # NOTE ON Scalar Fields 

183 # 

184 # pystencils permits explicit (`index_shape = (1,)`) and implicit (`index_shape = ()`) 

185 # scalar fields. In order to handle both equivalently, 

186 # we ignore the trivial explicit scalar dimension in field extraction. 

187 # This makes sure that explicit D-dimensional scalar fields 

188 # can be mapped onto D-dimensional data structures, and do not require that 

189 # D+1st dimension. 

190 

191 def __init__( 

192 self, 

193 psfield: Field, 

194 extraction: SupportsFieldExtraction, 

195 cast_indexing_symbols: bool = True, 

196 ): 

197 self._field = psfield 

198 self._extraction = extraction 

199 self._cast_indexing_symbols = cast_indexing_symbols 

200 

201 def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: 

202 # Find field pointer 

203 ptr: SfgKernelParamVar | None = None 

204 rank: int 

205 

206 if self._field.index_shape == (1,): 

207 # explicit scalar field -> ignore index dimensions 

208 rank = self._field.spatial_dimensions 

209 else: 

210 rank = len(self._field.shape) 

211 

212 shape: list[SfgKernelParamVar | str | None] = [None] * rank 

213 strides: list[SfgKernelParamVar | str | None] = [None] * rank 

214 

215 for param in ppc.live_variables: 

216 if isinstance(param, SfgKernelParamVar): 

217 for prop in param.wrapped.properties: 

218 match prop: 

219 case FieldBasePtr(field) if field == self._field: 

220 ptr = param 

221 case FieldShape(field, coord) if field == self._field: # type: ignore 

222 shape[coord] = param # type: ignore 

223 case FieldStride(field, coord) if field == self._field: # type: ignore 

224 strides[coord] = param # type: ignore 

225 

226 # Find constant or otherwise determined sizes 

227 for coord, s in enumerate(self._field.shape[:rank]): 

228 if shape[coord] is None: 

229 shape[coord] = str(s) 

230 

231 # Find constant or otherwise determined strides 

232 for coord, s in enumerate(self._field.strides[:rank]): 

233 if strides[coord] is None: 

234 strides[coord] = str(s) 

235 

236 # Now we have all the symbols, start extracting them 

237 nodes = [] 

238 done: set[SfgKernelParamVar] = set() 

239 

240 if ptr is not None: 

241 expr = self._extraction._extract_ptr() 

242 nodes.append( 

243 SfgStatements( 

244 f"{ptr.dtype.c_string()} {ptr.name} { {expr} } ;", 

245 (ptr,), 

246 depends(expr), 

247 includes(expr), 

248 ) 

249 ) 

250 

251 def maybe_cast(expr: AugExpr, target_type: PsType) -> AugExpr: 

252 if self._cast_indexing_symbols: 

253 return AugExpr(target_type).bind( 

254 "{}( {} )", deconstify(target_type).c_string(), expr 

255 ) 

256 else: 

257 return expr 

258 

259 def get_shape(coord, symb: SfgKernelParamVar | str): 

260 expr = self._extraction._extract_size(coord) 

261 

262 if expr is None: 

263 raise SfgException( 

264 f"Cannot extract shape in coordinate {coord} from {self._extraction}" 

265 ) 

266 

267 if isinstance(symb, SfgKernelParamVar) and symb not in done: 

268 done.add(symb) 

269 expr = maybe_cast(expr, symb.dtype) 

270 return SfgStatements( 

271 f"{symb.dtype.c_string()} {symb.name} { {expr} } ;", 

272 (symb,), 

273 depends(expr), 

274 includes(expr), 

275 ) 

276 else: 

277 return SfgStatements(f"/* {expr} == {symb} */", (), ()) 

278 

279 def get_stride(coord, symb: SfgKernelParamVar | str): 

280 expr = self._extraction._extract_stride(coord) 

281 

282 if expr is None: 

283 raise SfgException( 

284 f"Cannot extract stride in coordinate {coord} from {self._extraction}" 

285 ) 

286 

287 if isinstance(symb, SfgKernelParamVar) and symb not in done: 

288 done.add(symb) 

289 expr = maybe_cast(expr, symb.dtype) 

290 return SfgStatements( 

291 f"{symb.dtype.c_string()} {symb.name} { {expr} } ;", 

292 (symb,), 

293 depends(expr), 

294 includes(expr), 

295 ) 

296 else: 

297 return SfgStatements(f"/* {expr} == {symb} */", (), ()) 

298 

299 nodes += [get_shape(c, s) for c, s in enumerate(shape) if s is not None] 

300 nodes += [get_stride(c, s) for c, s in enumerate(strides) if s is not None] 

301 

302 return SfgSequence(nodes) 

303 

304 

class SfgDeferredVectorMapping(SfgDeferredNode):
    """Deferred mapping of scalar parameters onto the components of a vector.

    The scalar at position ``i`` of ``scalars`` corresponds to component ``i``
    of ``vector``. On expansion, every such scalar that is live is defined
    as ``<type> <name> { <component expr> };``.
    """

    def __init__(
        self,
        scalars: Sequence[sp.Symbol | SfgVar],
        vector: SupportsVectorExtraction,
    ):
        # name -> (component index, scalar); a later duplicate name overrides
        self._scalars = {sc.name: (i, sc) for i, sc in enumerate(scalars)}
        self._vector = vector

    def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
        nodes: list[SfgCallTreeNode] = []

        # Iterate in the order of `scalars` rather than over the live-variable
        # set, which has no guaranteed order, so the generated code is stable.
        for name, (idx, _) in self._scalars.items():
            param = ppc.get_live_variable(name)
            if param is None:
                continue

            expr = self._vector._extract_component(idx)
            nodes.append(
                SfgStatements(
                    # Doubled braces emit literal C++ brace initialization.
                    f"{param.dtype.c_string()} {param.name} {{ {expr} }};",
                    (param,),
                    depends(expr),
                    includes(expr),
                )
            )

        return SfgSequence(nodes)