Coverage for src/pystencilssfg/composer/class_composer.py: 82%
147 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from __future__ import annotations
2from typing import Sequence
3from itertools import takewhile, dropwhile
4import numpy as np
6from pystencils.types import create_type
8from ..context import SfgContext, SfgCursor
9from ..lang import (
10 VarLike,
11 ExprLike,
12 asvar,
13 SfgVar,
14)
16from ..ir import (
17 SfgCallTreeNode,
18 SfgClass,
19 SfgConstructor,
20 SfgMethod,
21 SfgMemberVariable,
22 SfgClassKeyword,
23 SfgVisibility,
24 SfgVisibilityBlock,
25 SfgEntityDecl,
26 SfgEntityDef,
27 SfgClassBody,
28)
29from ..exceptions import SfgException
31from .mixin import SfgComposerMixIn
32from .basic_composer import (
33 make_sequence,
34 SequencerArg,
35 SfgFunctionSequencerBase,
36)
class SfgMethodSequencer(SfgFunctionSequencerBase):
    """Composer-syntax builder for a single class method.

    Extends the function sequencer with the C++ method qualifiers
    ``const``, ``static``, ``virtual`` and ``override``. Each qualifier
    setter returns ``self`` so calls can be chained; calling the sequencer
    itself sets the method body.
    """

    def __init__(self, cursor: SfgCursor, name: str) -> None:
        super().__init__(cursor, name)

        # Method qualifiers, all off until enabled through the fluent setters.
        self._const: bool = False
        self._static: bool = False
        self._virtual: bool = False
        self._override: bool = False

        # The method body; only declared here, assigned by ``__call__``.
        self._tree: SfgCallTreeNode

    def _set_qualifier(self, attr_name: str):
        """Switch the boolean qualifier stored in ``attr_name`` on; return ``self``."""
        setattr(self, attr_name, True)
        return self

    def const(self):
        """Mark this method as ``const``."""
        return self._set_qualifier("_const")

    def static(self):
        """Mark this method as ``static``."""
        return self._set_qualifier("_static")

    def virtual(self):
        """Mark this method as ``virtual``."""
        return self._set_qualifier("_virtual")

    def override(self):
        """Mark this method as ``override``."""
        return self._set_qualifier("_override")

    def __call__(self, *args: SequencerArg):
        """Set the method body from the given sequencer arguments; return ``self``."""
        body = make_sequence(*args)
        self._tree = body
        return self

    def _resolve(self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock):
        """Build the ``SfgMethod``, register it on ``cls`` and emit it.

        Inline methods are defined directly inside ``vis_block``. Otherwise
        only a declaration goes into ``vis_block`` and the definition is
        written through ``ctx._cursor.write_impl``.
        """
        new_method = SfgMethod(
            self._name,
            cls,
            self._tree,
            return_type=self._return_type,
            inline=self._inline,
            const=self._const,
            static=self._static,
            constexpr=self._constexpr,
            virtual=self._virtual,
            override=self._override,
            attributes=self._attributes,
            required_params=self._params,
        )
        cls.add_member(new_method, vis_block.visibility)

        if not self._inline:
            # Out-of-line method: declare in the class body, define elsewhere.
            vis_block.elements.append(SfgEntityDecl(new_method))
            ctx._cursor.write_impl(SfgEntityDef(new_method))
            return

        vis_block.elements.append(SfgEntityDef(new_method))
98class SfgClassComposer(SfgComposerMixIn):
99 """Composer for classes and structs.
102 This class cannot be instantiated on its own but must be mixed in with
103 :class:`SfgBasicComposer`.
104 Its interface is exposed by :class:`SfgComposer`.
105 """
107 class VisibilityBlockSequencer:
108 """Represent a visibility block in the composer syntax.
110 Returned by `private`, `public`, and `protected`.
111 """
113 def __init__(self, visibility: SfgVisibility):
114 self._visibility = visibility
115 self._args: tuple[
116 SfgMethodSequencer
117 | SfgClassComposer.ConstructorBuilder
118 | VarLike
119 | str,
120 ...,
121 ]
123 def __call__(
124 self,
125 *args: (
126 SfgMethodSequencer | SfgClassComposer.ConstructorBuilder | VarLike | str
127 ),
128 ):
129 self._args = args
130 return self
132 def _resolve(self, ctx: SfgContext, cls: SfgClass) -> SfgVisibilityBlock:
133 vis_block = SfgVisibilityBlock(self._visibility)
134 for arg in self._args:
135 match arg:
136 case SfgMethodSequencer() | SfgClassComposer.ConstructorBuilder():
137 arg._resolve(ctx, cls, vis_block)
138 case str():
139 vis_block.elements.append(arg)
140 case _:
141 var = asvar(arg)
142 member_var = SfgMemberVariable(var.name, var.dtype, cls)
143 cls.add_member(member_var, vis_block.visibility)
144 vis_block.elements.append(SfgEntityDef(member_var))
145 return vis_block
147 class ConstructorBuilder:
148 """Composer syntax for constructor building.
150 Returned by `constructor`.
151 """
153 def __init__(self, *params: VarLike):
154 self._params = list(asvar(p) for p in params)
155 self._initializers: list[tuple[SfgVar | str, tuple[ExprLike, ...]]] = []
156 self._body: str | None = None
158 def add_param(self, param: VarLike, at: int | None = None):
159 if at is None:
160 self._params.append(asvar(param))
161 else:
162 self._params.insert(at, asvar(param))
164 @property
165 def parameters(self) -> list[SfgVar]:
166 return self._params
168 def init(self, var: VarLike | str):
169 """Add an initialization expression to the constructor's initializer list."""
171 member = var if isinstance(var, str) else asvar(var)
173 def init_sequencer(*args: ExprLike):
174 self._initializers.append((member, args))
175 return self
177 return init_sequencer
179 def body(self, body: str):
180 """Define the constructor body"""
181 if self._body is not None:
182 raise SfgException("Multiple definitions of constructor body.")
183 self._body = body
184 return self
186 def _resolve(
187 self, ctx: SfgContext, cls: SfgClass, vis_block: SfgVisibilityBlock
188 ):
189 ctor = SfgConstructor(
190 cls,
191 parameters=self._params,
192 initializers=self._initializers,
193 body=self._body if self._body is not None else "",
194 )
196 cls.add_member(ctor, vis_block.visibility)
197 vis_block.elements.append(SfgEntityDef(ctor))
199 def klass(self, class_name: str, bases: Sequence[str] = ()):
200 """Create a class and add it to the underlying context.
202 Args:
203 class_name: Name of the class
204 bases: List of base classes
205 """
206 return self._class(class_name, SfgClassKeyword.CLASS, bases)
208 def struct(self, class_name: str, bases: Sequence[str] = ()):
209 """Create a struct and add it to the underlying context.
211 Args:
212 class_name: Name of the struct
213 bases: List of base classes
214 """
215 return self._class(class_name, SfgClassKeyword.STRUCT, bases)
217 def numpy_struct(self, name: str, dtype: np.dtype, add_constructor: bool = True):
218 """Add a numpy structured data type as a C++ struct
220 Returns:
221 The created class object
222 """
223 return self._struct_from_numpy_dtype(name, dtype, add_constructor)
225 @property
226 def public(self) -> SfgClassComposer.VisibilityBlockSequencer:
227 """Create a `public` visibility block in a class body"""
228 return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PUBLIC)
230 @property
231 def protected(self) -> SfgClassComposer.VisibilityBlockSequencer:
232 """Create a `protected` visibility block in a class or struct body"""
233 return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PROTECTED)
235 @property
236 def private(self) -> SfgClassComposer.VisibilityBlockSequencer:
237 """Create a `private` visibility block in a class or struct body"""
238 return SfgClassComposer.VisibilityBlockSequencer(SfgVisibility.PRIVATE)
240 def constructor(self, *params: VarLike):
241 """In a class or struct body or visibility block, add a constructor.
243 Args:
244 params: List of constructor parameters
245 """
246 return SfgClassComposer.ConstructorBuilder(*params)
248 def method(self, name: str):
249 """In a class or struct body or visibility block, add a method.
250 The usage is similar to :any:`SfgBasicComposer.function`.
252 Args:
253 name: The method name
254 """
256 seq = SfgMethodSequencer(self._cursor, name)
257 if self._ctx.impl_file is None:
258 seq.inline()
259 return seq
261 # INTERNALS
263 def _class(self, class_name: str, keyword: SfgClassKeyword, bases: Sequence[str]):
264 # TODO: Return a `CppClass` instance representing the generated class
266 if self._cursor.get_entity(class_name) is not None:
267 raise ValueError(
268 f"Another entity with name {class_name} already exists in the current namespace."
269 )
271 cls = SfgClass(
272 class_name,
273 self._cursor.current_namespace,
274 class_keyword=keyword,
275 bases=bases,
276 )
277 self._cursor.add_entity(cls)
279 def sequencer(
280 *args: (
281 SfgClassComposer.VisibilityBlockSequencer
282 | SfgMethodSequencer
283 | SfgClassComposer.ConstructorBuilder
284 | VarLike
285 | str
286 ),
287 ):
288 default_vis_sequencer = SfgClassComposer.VisibilityBlockSequencer(
289 SfgVisibility.DEFAULT
290 )
292 def argfilter(arg):
293 return not isinstance(arg, SfgClassComposer.VisibilityBlockSequencer)
295 default_vis_args = takewhile(
296 argfilter,
297 args,
298 )
299 default_block = default_vis_sequencer(*default_vis_args)._resolve(self._ctx, cls) # type: ignore
300 vis_blocks: list[SfgVisibilityBlock] = []
302 for arg in dropwhile(argfilter, args):
303 if isinstance(arg, SfgClassComposer.VisibilityBlockSequencer):
304 vis_blocks.append(arg._resolve(self._ctx, cls))
305 else:
306 raise SfgException(
307 "Composer Syntax Error: "
308 "Cannot add members with default visibility after a visibility block."
309 )
311 self._cursor.write_header(SfgClassBody(cls, default_block, vis_blocks))
313 return sequencer
315 def _struct_from_numpy_dtype(
316 self, struct_name: str, dtype: np.dtype, add_constructor: bool = True
317 ):
318 fields = dtype.fields
319 if fields is None:
320 raise SfgException(f"Numpy dtype {dtype} is not a structured type.")
322 members: list[SfgClassComposer.ConstructorBuilder | SfgVar] = []
323 if add_constructor:
324 ctor = self.constructor()
325 members.append(ctor)
327 for member_name, type_info in fields.items():
328 member_type = create_type(type_info[0])
330 member = SfgVar(member_name, member_type)
331 members.append(member)
333 if add_constructor:
334 arg = SfgVar(f"{member_name}_", member_type)
335 ctor.add_param(arg)
336 ctor.init(member)(arg)
338 return self.struct(
339 struct_name,
340 )(*members)