Coverage for src/pystencilssfg/composer/class_composer.py: 82%

147 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from __future__ import annotations 

2from typing import Sequence 

3from itertools import takewhile, dropwhile 

4import numpy as np 

5 

6from pystencils.types import create_type 

7 

8from ..context import SfgContext, SfgCursor 

9from ..lang import ( 

10 VarLike, 

11 ExprLike, 

12 asvar, 

13 SfgVar, 

14) 

15 

16from ..ir import ( 

17 SfgCallTreeNode, 

18 SfgClass, 

19 SfgConstructor, 

20 SfgMethod, 

21 SfgMemberVariable, 

22 SfgClassKeyword, 

23 SfgVisibility, 

24 SfgVisibilityBlock, 

25 SfgEntityDecl, 

26 SfgEntityDef, 

27 SfgClassBody, 

28) 

29from ..exceptions import SfgException 

30 

31from .mixin import SfgComposerMixIn 

32from .basic_composer import ( 

33 make_sequence, 

34 SequencerArg, 

35 SfgFunctionSequencerBase, 

36) 

37 

38 

class SfgMethodSequencer(SfgFunctionSequencerBase):
    """Composer syntax for a method of a class or struct.

    Returned by :any:`SfgClassComposer.method`. Qualifiers are set through the
    fluent modifiers defined here (`const`, `static`, `virtual`, `override`)
    and those inherited from `SfgFunctionSequencerBase` (e.g. `inline`).
    The method body is supplied by calling the sequencer with the body's
    sequencer arguments.
    """

    def __init__(self, cursor: SfgCursor, name: str) -> None:
        super().__init__(cursor, name)

        self._const: bool = False
        self._static: bool = False
        self._virtual: bool = False
        self._override: bool = False

        # Assigned by `__call__` once the body is given; intentionally left
        # unassigned here so `_resolve` can detect a missing body.
        self._tree: SfgCallTreeNode

    def const(self):
        """Mark this method as ``const``."""
        self._const = True
        return self

    def static(self):
        """Mark this method as ``static``."""
        self._static = True
        return self

    def virtual(self):
        """Mark this method as ``virtual``."""
        self._virtual = True
        return self

    def override(self):
        """Mark this method as ``override``."""
        self._override = True
        return self

    def __call__(self, *args: SequencerArg):
        """Define the method body from the given sequencer arguments.

        Returns:
            This sequencer, so it can be passed on to a class or visibility block.
        """
        self._tree = make_sequence(*args)
        return self

    def _resolve(self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock):
        """Create the `SfgMethod`, register it with `cls`, and emit its code.

        Inline methods are defined directly inside the visibility block.
        Non-inline methods get a declaration in the visibility block, and
        their definition is written to the implementation file through the
        context's cursor.

        Raises:
            SfgException: If the sequencer was never called with a method body.
        """
        tree = getattr(self, "_tree", None)
        if tree is None:
            # Previously this surfaced as an opaque AttributeError.
            raise SfgException(
                "Composer Syntax Error: "
                f"No body was given for method `{self._name}`. "
                "Call the method sequencer with the method body, "
                f'e.g. sfg.method("{self._name}")(...).'
            )

        method = SfgMethod(
            self._name,
            cls,
            tree,
            return_type=self._return_type,
            inline=self._inline,
            const=self._const,
            static=self._static,
            constexpr=self._constexpr,
            virtual=self._virtual,
            override=self._override,
            attributes=self._attributes,
            required_params=self._params,
        )
        cls.add_member(method, vis_block.visibility)

        if self._inline:
            vis_block.elements.append(SfgEntityDef(method))
        else:
            vis_block.elements.append(SfgEntityDecl(method))
            ctx._cursor.write_impl(SfgEntityDef(method))

96 

97 

98class SfgClassComposer(SfgComposerMixIn): 

99 """Composer for classes and structs. 

100 

101 

102 This class cannot be instantiated on its own but must be mixed in with 

103 :class:`SfgBasicComposer`. 

104 Its interface is exposed by :class:`SfgComposer`. 

105 """ 

106 

107 class VisibilityBlockSequencer: 

108 """Represent a visibility block in the composer syntax. 

109 

110 Returned by `private`, `public`, and `protected`. 

111 """ 

112 

113 def __init__(self, visibility: SfgVisibility): 

114 self._visibility = visibility 

115 self._args: tuple[ 

116 SfgMethodSequencer 

117 | SfgClassComposer.ConstructorBuilder 

118 | VarLike 

119 | str, 

120 ..., 

121 ] 

122 

123 def __call__( 

124 self, 

125 *args: ( 

126 SfgMethodSequencer | SfgClassComposer.ConstructorBuilder | VarLike | str 

127 ), 

128 ): 

129 self._args = args 

130 return self 

131 

132 def _resolve(self, ctx: SfgContext, cls: SfgClass) -> SfgVisibilityBlock: 

133 vis_block = SfgVisibilityBlock(self._visibility) 

134 for arg in self._args: 

135 match arg: 

136 case SfgMethodSequencer() | SfgClassComposer.ConstructorBuilder(): 

137 arg._resolve(ctx, cls, vis_block) 

138 case str(): 

139 vis_block.elements.append(arg) 

140 case _: 

141 var = asvar(arg) 

142 member_var = SfgMemberVariable(var.name, var.dtype, cls) 

143 cls.add_member(member_var, vis_block.visibility) 

144 vis_block.elements.append(SfgEntityDef(member_var)) 

145 return vis_block 

146 

147 class ConstructorBuilder: 

148 """Composer syntax for constructor building. 

149 

150 Returned by `constructor`. 

151 """ 

152 

153 def __init__(self, *params: VarLike): 

154 self._params = list(asvar(p) for p in params) 

155 self._initializers: list[tuple[SfgVar | str, tuple[ExprLike, ...]]] = [] 

156 self._body: str | None = None 

157 

158 def add_param(self, param: VarLike, at: int | None = None): 

159 if at is None: 

160 self._params.append(asvar(param)) 

161 else: 

162 self._params.insert(at, asvar(param)) 

163 

164 @property 

165 def parameters(self) -> list[SfgVar]: 

166 return self._params 

167 

168 def init(self, var: VarLike | str): 

169 """Add an initialization expression to the constructor's initializer list.""" 

170 

171 member = var if isinstance(var, str) else asvar(var) 

172 

173 def init_sequencer(*args: ExprLike): 

174 self._initializers.append((member, args)) 

175 return self 

176 

177 return init_sequencer 

178 

179 def body(self, body: str): 

180 """Define the constructor body""" 

181 if self._body is not None: 

182 raise SfgException("Multiple definitions of constructor body.") 

183 self._body = body 

184 return self 

185 

186 def _resolve( 

187 self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock 

188 ): 

189 ctor = SfgConstructor( 

190 cls, 

191 parameters=self._params, 

192 initializers=self._initializers, 

193 body=self._body if self._body is not None else "", 

194 ) 

195 

196 cls.add_member(ctor, vis_block.visibility) 

197 vis_block.elements.append(SfgEntityDef(ctor)) 

198 

199 def klass(self, class_name: str, bases: Sequence[str] = ()): 

200 """Create a class and add it to the underlying context. 

201 

202 Args: 

203 class_name: Name of the class 

204 bases: List of base classes 

205 """ 

206 return self._class(class_name, SfgClassKeyword.CLASS, bases) 

207 

208 def struct(self, class_name: str, bases: Sequence[str] = ()): 

209 """Create a struct and add it to the underlying context. 

210 

211 Args: 

212 class_name: Name of the struct 

213 bases: List of base classes 

214 """ 

215 return self._class(class_name, SfgClassKeyword.STRUCT, bases) 

216 

217 def numpy_struct(self, name: str, dtype: np.dtype, add_constructor: bool = True): 

218 """Add a numpy structured data type as a C++ struct 

219 

220 Returns: 

221 The created class object 

222 """ 

223 return self._struct_from_numpy_dtype(name, dtype, add_constructor) 

224 

225 @property 

226 def public(self) -> SfgClassComposer.VisibilityBlockSequencer: 

227 """Create a `public` visibility block in a class body""" 

228 return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PUBLIC) 

229 

230 @property 

231 def protected(self) -> SfgClassComposer.VisibilityBlockSequencer: 

232 """Create a `protected` visibility block in a class or struct body""" 

233 return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PROTECTED) 

234 

235 @property 

236 def private(self) -> SfgClassComposer.VisibilityBlockSequencer: 

237 """Create a `private` visibility block in a class or struct body""" 

238 return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PRIVATE) 

239 

240 def constructor(self, *params: VarLike): 

241 """In a class or struct body or visibility block, add a constructor. 

242 

243 Args: 

244 params: List of constructor parameters 

245 """ 

246 return SfgClassComposer.ConstructorBuilder(*params) 

247 

248 def method(self, name: str): 

249 """In a class or struct body or visibility block, add a method. 

250 The usage is similar to :any:`SfgBasicComposer.function`. 

251 

252 Args: 

253 name: The method name 

254 """ 

255 

256 seq = SfgMethodSequencer(self._cursor, name) 

257 if self._ctx.impl_file is None: 

258 seq.inline() 

259 return seq 

260 

261 # INTERNALS 

262 

263 def _class(self, class_name: str, keyword: SfgClassKeyword, bases: Sequence[str]): 

264 # TODO: Return a `CppClass` instance representing the generated class 

265 

266 if self._cursor.get_entity(class_name) is not None: 

267 raise ValueError( 

268 f"Another entity with name {class_name} already exists in the current namespace." 

269 ) 

270 

271 cls = SfgClass( 

272 class_name, 

273 self._cursor.current_namespace, 

274 class_keyword=keyword, 

275 bases=bases, 

276 ) 

277 self._cursor.add_entity(cls) 

278 

279 def sequencer( 

280 *args: ( 

281 SfgClassComposer.VisibilityBlockSequencer 

282 | SfgMethodSequencer 

283 | SfgClassComposer.ConstructorBuilder 

284 | VarLike 

285 | str 

286 ), 

287 ): 

288 default_vis_sequencer = SfgClassComposer.VisibilityBlockSequencer( 

289 SfgVisibility.DEFAULT 

290 ) 

291 

292 def argfilter(arg): 

293 return not isinstance(arg, SfgClassComposer.VisibilityBlockSequencer) 

294 

295 default_vis_args = takewhile( 

296 argfilter, 

297 args, 

298 ) 

299 default_block = default_vis_sequencer(*default_vis_args)._resolve(self._ctx, cls) # type: ignore 

300 vis_blocks: list[SfgVisibilityBlock] = [] 

301 

302 for arg in dropwhile(argfilter, args): 

303 if isinstance(arg, SfgClassComposer.VisibilityBlockSequencer): 

304 vis_blocks.append(arg._resolve(self._ctx, cls)) 

305 else: 

306 raise SfgException( 

307 "Composer Syntax Error: " 

308 "Cannot add members with default visibility after a visibility block." 

309 ) 

310 

311 self._cursor.write_header(SfgClassBody(cls, default_block, vis_blocks)) 

312 

313 return sequencer 

314 

315 def _struct_from_numpy_dtype( 

316 self, struct_name: str, dtype: np.dtype, add_constructor: bool = True 

317 ): 

318 fields = dtype.fields 

319 if fields is None: 

320 raise SfgException(f"Numpy dtype {dtype} is not a structured type.") 

321 

322 members: list[SfgClassComposer.ConstructorBuilder | SfgVar] = [] 

323 if add_constructor: 

324 ctor = self.constructor() 

325 members.append(ctor) 

326 

327 for member_name, type_info in fields.items(): 

328 member_type = create_type(type_info[0]) 

329 

330 member = SfgVar(member_name, member_type) 

331 members.append(member) 

332 

333 if add_constructor: 

334 arg = SfgVar(f"{member_name}_", member_type) 

335 ctor.add_param(arg) 

336 ctor.init(member)(arg) 

337 

338 return self.struct( 

339 struct_name, 

340 )(*members)