Coverage for src/pystencilssfg/ir/entities.py: 87%
270 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from __future__ import annotations
3from dataclasses import dataclass
4from abc import ABC
5from enum import Enum, auto
6from typing import (
7 TYPE_CHECKING,
8 Sequence,
9 Generator,
10)
11from itertools import chain
13from pystencils import Field
14from pystencils.codegen import Kernel
15from pystencils.types import PsType, PsCustomType
17from ..lang import SfgVar, SfgKernelParamVar, void, ExprLike
18from ..exceptions import SfgException
20if TYPE_CHECKING:
21 from . import SfgCallTreeNode
# =========================================================================================================
#
# SEMANTIC ENTITIES
#
# These classes model *code entities*, which represent *semantic components* of the generated files.
#
# =========================================================================================================
class SfgCodeEntity:
    """Common base of all named code entities.

    A code entity carries a local (unqualified) name and a reference to the
    namespace that directly encloses it.

    Args:
        name: Local name of the entity
        parent_namespace: Namespace enclosing the entity
    """

    def __init__(self, name: str, parent_namespace: SfgNamespace) -> None:
        self._namespace: SfgNamespace = parent_namespace
        self._name = name

    @property
    def name(self) -> str:
        """Local, unqualified name of this entity"""
        return self._name

    @property
    def fqname(self) -> str:
        """Fully qualified name of this entity.

        Entities placed directly in the global namespace carry no prefix;
        all others are prefixed with their parent's qualified name and ``::``.
        """
        enclosing = self._namespace
        if isinstance(enclosing, SfgGlobalNamespace):
            return self._name
        return f"{enclosing.fqname}::{self._name}"

    @property
    def parent_namespace(self) -> SfgNamespace | None:
        """Namespace directly enclosing this entity"""
        return self._namespace
62class SfgNamespace(SfgCodeEntity):
63 """A C++ namespace.
65 Each namespace has a name and a parent; its fully qualified name is given as
66 ``<parent.name>::<name>``.
68 Args:
69 name: Local name of this namespace
70 parent: Parent namespace enclosing this namespace
71 """
73 def __init__(self, name: str, parent_namespace: SfgNamespace) -> None:
74 super().__init__(name, parent_namespace)
76 self._entities: dict[str, SfgCodeEntity] = dict()
78 def get_entity(self, qual_name: str) -> SfgCodeEntity | None:
79 """Find an entity with the given qualified name within this namespace.
81 If ``qual_name`` contains any qualifying delimiters ``::``,
82 each component but the last is interpreted as a namespace.
83 """
84 tokens = qual_name.split("::", 1)
85 match tokens:
86 case [entity_name]:
87 return self._entities.get(entity_name, None)
88 case [nspace, remaining_qualname]:
89 sub_nspace = self._entities.get(nspace, None)
90 if sub_nspace is not None:
91 if not isinstance(sub_nspace, SfgNamespace):
92 raise KeyError(
93 f"Unable to find entity {qual_name} in namespace {self._name}: "
94 f"Entity {nspace} is not a namespace."
95 )
96 return sub_nspace.get_entity(remaining_qualname)
97 else:
98 return None
99 case _:
100 assert False, "unreachable code"
102 def add_entity(self, entity: SfgCodeEntity):
103 if entity.name in self._entities:
104 raise ValueError(
105 f"Another entity with the name {entity.fqname} already exists"
106 )
107 self._entities[entity.name] = entity
109 def get_child_namespace(self, qual_name: str):
110 if not qual_name:
111 raise ValueError("Anonymous namespaces are not supported")
113 # Find the namespace by qualified lookup ...
114 namespace = self.get_entity(qual_name)
115 if namespace is not None:
116 if not type(namespace) is SfgNamespace:
117 raise ValueError(f"Entity {qual_name} exists, but is not a namespace")
118 else:
119 # ... or create it
120 tokens = qual_name.split("::")
121 namespace = self
122 for tok in tokens:
123 namespace = SfgNamespace(tok, namespace)
125 return namespace
class SfgGlobalNamespace(SfgNamespace):
    """The C++ global namespace.

    It has an empty name and acts as its own parent namespace.
    """

    def __init__(self) -> None:
        SfgNamespace.__init__(self, "", parent_namespace=self)

    @property
    def fqname(self) -> str:
        """The global namespace's qualified name is always empty"""
        return ""
class SfgKernelHandle(SfgCodeEntity):
    """Handle to a pystencils kernel.

    Args:
        name: Name of the kernel
        namespace: Kernel namespace the kernel is placed in
        kernel: The wrapped pystencils kernel object
        inline: Value of the ``inline`` flag of this kernel
    """

    __match_args__ = ("kernel", "parameters")

    def __init__(
        self,
        name: str,
        namespace: SfgKernelNamespace,
        kernel: Kernel,
        inline: bool = False,
    ):
        super().__init__(name, namespace)

        self._kernel = kernel
        self._inline: bool = inline

        #   Wrap every pystencils parameter, then sort the wrappers into scalar
        #   parameters and field parameters (the latter contribute their fields).
        self._parameters = [SfgKernelParamVar(p) for p in kernel.parameters]
        self._scalar_params: set[SfgVar] = set()
        self._fields: set[Field] = set()

        for param in self._parameters:
            wrapped = param.wrapped
            if not wrapped.is_field_parameter:
                self._scalar_params.add(param)
                continue
            self._fields.update(wrapped.fields)

    @property
    def parameters(self) -> Sequence[SfgKernelParamVar]:
        """Parameters to this kernel, in the order of ``kernel.parameters``"""
        return self._parameters

    @property
    def scalar_parameters(self) -> set[SfgVar]:
        """Parameters to this kernel that are not field parameters"""
        return self._scalar_params

    @property
    def fields(self) -> set[Field]:
        """Fields referenced by the field parameters of this kernel"""
        return self._fields

    @property
    def kernel(self) -> Kernel:
        """Underlying pystencils kernel object"""
        return self._kernel

    @property
    def inline(self) -> bool:
        """Value of the ``inline`` flag passed at construction"""
        return self._inline
class SfgKernelNamespace(SfgNamespace):
    """A namespace grouping together a number of kernels.

    Kernels are kept in insertion order and must have unique names.

    Args:
        name: Local name of this namespace
        parent: Parent namespace enclosing this namespace
    """

    def __init__(self, name: str, parent: SfgNamespace):
        super().__init__(name, parent)
        #   Kernels registered via `add_kernel`, keyed by kernel name
        self._kernels: dict[str, SfgKernelHandle] = dict()

    @property
    def name(self):
        """Local name of this kernel namespace"""
        return self._name

    @property
    def kernels(self) -> tuple[SfgKernelHandle, ...]:
        """All kernels of this namespace, in the order they were added"""
        return tuple(self._kernels.values())

    def find_kernel(self, name: str) -> SfgKernelHandle | None:
        """Return the kernel called ``name``, or ``None`` if there is none"""
        return self._kernels.get(name)

    def add_kernel(self, kernel: SfgKernelHandle):
        """Register ``kernel`` in this namespace.

        Raises:
            ValueError: If a kernel with the same name was already added
        """
        kname = kernel.name
        if kname not in self._kernels:
            self._kernels[kname] = kernel
            return
        raise ValueError(
            f"Duplicate kernels: A kernel called {kernel.name} already exists "
            f"in namespace {self.fqname}"
        )
219@dataclass(frozen=True, match_args=False)
220class CommonFunctionProperties:
221 tree: SfgCallTreeNode
222 parameters: tuple[SfgVar, ...]
223 return_type: PsType
224 inline: bool
225 constexpr: bool
226 attributes: Sequence[str]
228 @staticmethod
229 def collect_params(tree: SfgCallTreeNode, required_params: Sequence[SfgVar] | None):
230 from .postprocessing import CallTreePostProcessing
232 param_collector = CallTreePostProcessing()
233 params_set = param_collector(tree).function_params
235 if required_params is not None:
236 if not (params_set <= set(required_params)):
237 extras = params_set - set(required_params)
238 raise SfgException(
239 "Extraenous function parameters: "
240 f"Found free variables {extras} that were not listed in manually specified function parameters."
241 )
242 parameters = tuple(required_params)
243 else:
244 parameters = tuple(sorted(params_set, key=lambda p: p.name))
246 return parameters
class SfgFunction(SfgCodeEntity, CommonFunctionProperties):
    """A free function.

    Args:
        name: Name of the function
        namespace: Namespace enclosing the function
        tree: Call tree forming the function body
        return_type: Return type of the function
        inline: Value of the ``inline`` flag
        constexpr: Value of the ``constexpr`` flag
        attributes: Attribute strings of the function
        required_params: Explicit parameter list; if omitted, the parameters are
            derived from the free variables of ``tree``
    """

    __match_args__ = ("name", "tree", "parameters", "return_type")  # type: ignore

    def __init__(
        self,
        name: str,
        namespace: SfgNamespace,
        tree: SfgCallTreeNode,
        return_type: PsType = void,
        inline: bool = False,
        constexpr: bool = False,
        attributes: Sequence[str] = (),
        required_params: Sequence[SfgVar] | None = None,
    ):
        SfgCodeEntity.__init__(self, name, namespace)

        CommonFunctionProperties.__init__(
            self,
            tree=tree,
            parameters=self.collect_params(tree, required_params),
            return_type=return_type,
            inline=inline,
            constexpr=constexpr,
            attributes=attributes,
        )
class SfgVisibility(Enum):
    """Visibility qualifiers of C++.

    ``DEFAULT`` stands for "no explicit qualifier" and renders as an empty string;
    all other members render as their lower-case keyword.
    """

    DEFAULT = auto()
    PRIVATE = auto()
    PROTECTED = auto()
    PUBLIC = auto()

    def __str__(self) -> str:
        if self is SfgVisibility.DEFAULT:
            return ""
        return self.name.lower()
class SfgClassKeyword(Enum):
    """Class keywords of C++; each member renders as its lower-case keyword."""

    STRUCT = auto()
    CLASS = auto()

    def __str__(self) -> str:
        return self.name.lower()
class SfgClassMember(ABC):
    """Common base of all entities that can be members of a class.

    Args:
        cls: Class owning this member
    """

    def __init__(self, cls: SfgClass) -> None:
        self._cls: SfgClass = cls
        #   Remains ``None`` unless assigned; `visibility` raises while it is ``None``
        self._visibility: SfgVisibility | None = None

    @property
    def owning_class(self) -> SfgClass:
        """The class this member belongs to.

        Raises:
            SfgException: If no owning class is set
        """
        owner = self._cls
        if owner is not None:
            return owner
        raise SfgException(f"{self} is not bound to a class.")

    @property
    def visibility(self) -> SfgVisibility:
        """Visibility of this member inside its owning class.

        Raises:
            SfgException: If no visibility has been assigned
        """
        vis = self._visibility
        if vis is not None:
            return vis
        raise SfgException(
            f"{self} is not bound to a class and therefore has no visibility."
        )
class SfgMemberVariable(SfgVar, SfgClassMember):
    """A variable that is a data member of a class.

    Args:
        name: Name of the member variable
        dtype: Data type of the member variable
        cls: Class owning this member
        default_init: Optional default initializer expressions
    """

    def __init__(
        self,
        name: str,
        dtype: PsType,
        cls: SfgClass,
        default_init: tuple[ExprLike, ...] | None = None,
    ):
        #   Both bases are initialized explicitly, variable part first
        SfgVar.__init__(self, name, dtype)
        SfgClassMember.__init__(self, cls)

        self._default_init = default_init

    @property
    def default_init(self) -> tuple[ExprLike, ...] | None:
        """Default initializer expressions, or ``None`` if none were given"""
        return self._default_init
class SfgMethod(SfgClassMember, CommonFunctionProperties):
    """Instance method of a class.

    Args:
        name: Name of the method
        cls: Class owning this method
        tree: Call tree forming the method body
        return_type: Return type of the method
        inline: Value of the ``inline`` flag
        const: Value of the ``const`` qualifier flag
        static: Value of the ``static`` flag
        constexpr: Value of the ``constexpr`` flag
        virtual: Value of the ``virtual`` flag
        override: Value of the ``override`` flag
        attributes: Attribute strings of the method
        required_params: Explicit parameter list; if omitted, the parameters are
            derived from the free variables of ``tree``
    """

    __match_args__ = ("name", "tree", "parameters", "return_type")  # type: ignore

    def __init__(
        self,
        name: str,
        cls: SfgClass,
        tree: SfgCallTreeNode,
        return_type: PsType = void,
        inline: bool = False,
        const: bool = False,
        static: bool = False,
        constexpr: bool = False,
        virtual: bool = False,
        override: bool = False,
        attributes: Sequence[str] = (),
        required_params: Sequence[SfgVar] | None = None,
    ):
        SfgClassMember.__init__(self, cls)

        self._name = name

        #   Method-specific qualifier flags
        self._const = const
        self._static = static
        self._virtual = virtual
        self._override = override

        CommonFunctionProperties.__init__(
            self,
            tree=tree,
            parameters=self.collect_params(tree, required_params),
            return_type=return_type,
            inline=inline,
            constexpr=constexpr,
            attributes=attributes,
        )

    @property
    def name(self) -> str:
        """Name of this method"""
        return self._name

    @property
    def static(self) -> bool:
        """Value of the ``static`` flag"""
        return self._static

    @property
    def const(self) -> bool:
        """Value of the ``const`` qualifier flag"""
        return self._const

    @property
    def virtual(self) -> bool:
        """Value of the ``virtual`` flag"""
        return self._virtual

    @property
    def override(self) -> bool:
        """Value of the ``override`` flag"""
        return self._override
class SfgConstructor(SfgClassMember):
    """Constructor of a class.

    Args:
        cls: Class owning this constructor
        parameters: Constructor parameters
        initializers: Initializer entries, each pairing a target (variable or
            name) with a tuple of argument expressions
        body: Constructor body, given as a string
    """

    __match_args__ = ("owning_class", "parameters", "initializers", "body")

    def __init__(
        self,
        cls: SfgClass,
        parameters: Sequence[SfgVar] = (),
        initializers: Sequence[tuple[SfgVar | str, tuple[ExprLike, ...]]] = (),
        body: str = "",
    ):
        SfgClassMember.__init__(self, cls)
        #   Store immutable snapshots of the given sequences
        self._parameters = tuple(parameters)
        self._initializers = tuple(initializers)
        self._body = body

    @property
    def parameters(self) -> tuple[SfgVar, ...]:
        """Constructor parameters"""
        return self._parameters

    @property
    def initializers(self) -> tuple[tuple[SfgVar | str, tuple[ExprLike, ...]], ...]:
        """Initializer entries of this constructor"""
        return self._initializers

    @property
    def body(self) -> str:
        """Constructor body string"""
        return self._body
class SfgClass(SfgCodeEntity):
    """A C++ class.

    Args:
        name: Name of the class
        namespace: Namespace enclosing the class
        class_keyword: Keyword the class is declared with (``class`` or ``struct``)
        bases: Base classes, given as a sequence of strings (a bare string is rejected)
    """

    __match_args__ = ("class_keyword", "name")

    def __init__(
        self,
        name: str,
        namespace: SfgNamespace,
        class_keyword: SfgClassKeyword = SfgClassKeyword.CLASS,
        bases: Sequence[str] = (),
    ):
        #   A bare string is itself a sequence; reject it to avoid splitting it into characters
        if isinstance(bases, str):
            raise ValueError("Base classes must be given as a sequence.")

        super().__init__(name, namespace)

        self._class_keyword = class_keyword
        self._bases_classes = tuple(bases)

        self._constructors: list[SfgConstructor] = []
        self._methods: list[SfgMethod] = []
        self._member_vars: dict[str, SfgMemberVariable] = dict()

    @property
    def src_type(self) -> PsType:
        """Type object referring to this class by its (unqualified) name"""
        # TODO: Use CppTypeFactory instead
        return PsCustomType(self._name)

    @property
    def base_classes(self) -> tuple[str, ...]:
        """Base classes of this class"""
        return self._bases_classes

    @property
    def class_keyword(self) -> SfgClassKeyword:
        """Keyword this class is declared with"""
        return self._class_keyword

    def members(
        self, visibility: SfgVisibility | None = None
    ) -> Generator[SfgClassMember, None, None]:
        """Iterate constructors, methods and member variables, in that order.

        If ``visibility`` is given, only members with that visibility are yielded.
        """
        if visibility is None:
            yield from chain(
                self._constructors, self._methods, self._member_vars.values()
            )
        else:
            yield from filter(lambda m: m.visibility == visibility, self.members())

    def member_variables(
        self, visibility: SfgVisibility | None = None
    ) -> Generator[SfgMemberVariable, None, None]:
        """Iterate member variables, optionally filtered by ``visibility``"""
        if visibility is not None:
            yield from filter(
                lambda m: m.visibility == visibility, self._member_vars.values()
            )
        else:
            yield from self._member_vars.values()

    def constructors(
        self, visibility: SfgVisibility | None = None
    ) -> Generator[SfgConstructor, None, None]:
        """Iterate constructors, optionally filtered by ``visibility``"""
        if visibility is not None:
            yield from filter(lambda m: m.visibility == visibility, self._constructors)
        else:
            yield from self._constructors

    def methods(
        self, visibility: SfgVisibility | None = None
    ) -> Generator[SfgMethod, None, None]:
        """Iterate methods, optionally filtered by ``visibility``"""
        if visibility is not None:
            yield from filter(lambda m: m.visibility == visibility, self._methods)
        else:
            yield from self._methods

    def add_member(self, member: SfgClassMember, vis: SfgVisibility):
        """Add ``member`` to this class with visibility ``vis``.

        Raises:
            SfgException: If ``member`` is not a constructor, member variable or
                method, or if a member variable of the same name already exists
        """
        if isinstance(member, SfgConstructor):
            self._constructors.append(member)
        elif isinstance(member, SfgMemberVariable):
            self._add_member_variable(member)
        elif isinstance(member, SfgMethod):
            self._methods.append(member)
        else:
            raise SfgException(f"{member} is not a valid class member.")

        #   Record the visibility only after the member was actually added, so the
        #   visibility-filtered views above (which read `member.visibility`) work.
        member._visibility = vis

    def _add_member_variable(self, variable: SfgMemberVariable):
        """Register a member variable, rejecting duplicate names."""
        if variable.name in self._member_vars:
            raise SfgException(
                f"Duplicate field name {variable.name} in class {self._name}"
            )

        self._member_vars[variable.name] = variable