Coverage for src/pystencilssfg/lang/types.py: 94%
116 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from __future__ import annotations
2from typing import Any, Iterable, Sequence, Mapping, TypeVar, Generic
3from abc import ABC
4from dataclasses import dataclass
5from itertools import chain
7import string
9from pystencils.types import PsType, PsPointerType, PsCustomType
10from .headers import HeaderFile
class VoidType(PsType):
    """C++ ``void`` type.

    The ``const`` constructor argument is accepted but not forwarded:
    the base class is always initialized as non-const, and `c_string`
    never emits a qualifier.
    """

    def __init__(self, const: bool = False):
        # `const` is intentionally not passed on; void is created unqualified.
        super().__init__(False)

    def __args__(self) -> tuple[Any, ...]:
        """Return the constructor arguments of this type (there are none)."""
        return tuple()

    def c_string(self) -> str:
        """Render this type as C++ source text."""
        return "void"

    def __repr__(self) -> str:
        return "VoidType()"
# Module-level instance of the C++ `void` type, for use wherever a void type is needed.
void = VoidType()
32class _TemplateArgFormatter(string.Formatter):
34 def format_field(self, arg, format_spec):
35 if isinstance(arg, PsType):
36 arg = arg.c_string()
37 return super().format_field(arg, format_spec)
39 def check_unused_args(
40 self, used_args: set[int | str], args: Sequence, kwargs: Mapping[str, Any]
41 ) -> None:
42 max_args_len: int = (
43 max((k for k in used_args if isinstance(k, int)), default=-1) + 1
44 )
45 if len(args) > max_args_len:
46 raise ValueError(
47 f"Too many positional arguments: Expected {max_args_len}, but got {len(args)}"
48 )
50 extra_keys = set(kwargs.keys()) - used_args # type: ignore
51 if extra_keys:
52 raise ValueError(f"Extraneous keyword arguments: {extra_keys}")
55@dataclass(frozen=True)
56class _TemplateArgs:
57 pargs: tuple[Any, ...]
58 kwargs: tuple[tuple[str, Any], ...]
61class CppType(PsCustomType, ABC):
62 class_includes: frozenset[HeaderFile]
63 template_string: str
65 def __init__(self, *template_args, const: bool = False, **template_kwargs):
66 # Support for cloning CppTypes
67 if template_args and isinstance(template_args[0], _TemplateArgs):
68 assert not template_kwargs
69 targs = template_args[0]
70 pargs = targs.pargs
71 kwargs = dict(targs.kwargs)
72 else:
73 pargs = template_args
74 kwargs = template_kwargs
75 targs = _TemplateArgs(
76 pargs, tuple(sorted(kwargs.items(), key=lambda t: t[0]))
77 )
79 formatter = _TemplateArgFormatter()
80 name = formatter.format(self.template_string, *pargs, **kwargs)
82 self._targs = targs
83 self._includes = self.class_includes
85 for arg in chain(pargs, kwargs.values()):
86 match arg:
87 case CppType():
88 self._includes |= arg.includes
89 case PsType():
90 self._includes |= {
91 HeaderFile.parse(h) for h in arg.required_headers
92 }
94 super().__init__(name, const=const)
96 def __args__(self) -> tuple[Any, ...]:
97 return (self._targs,)
99 @property
100 def includes(self) -> frozenset[HeaderFile]:
101 return self._includes
103 @property
104 def required_headers(self) -> set[str]:
105 return set(str(h) for h in self.class_includes)
# Type variable parametrizing `CppTypeFactory` by the concrete `CppType` subclass it instantiates.
TypeClass_T = TypeVar("TypeClass_T", bound=CppType)
"""Python type variable bound to `CppType`."""
class CppTypeFactory(Generic[TypeClass_T]):
    """Type Factory returned by `cpptype`.

    Wraps a concrete `CppType` subclass and instantiates it when called.
    """

    def __init__(self, tclass: type[TypeClass_T]) -> None:
        self._type_class = tclass

    @property
    def includes(self) -> frozenset[HeaderFile]:
        """Set of headers required by this factory's type"""
        return self._type_class.class_includes

    @property
    def template_string(self) -> str:
        """Template string of this factory's type"""
        return self._type_class.template_string

    def _includes_str(self) -> str:
        """Comma-separated header names, sorted so the output is deterministic."""
        return ", ".join(sorted(str(i) for i in self.includes))

    def __str__(self) -> str:
        return f"Factory for `{self.template_string}` defined in {{{self._includes_str()}}}"

    def __repr__(self) -> str:
        # `{{` / `}}` emit literal braces around the header list
        return f"CppTypeFactory({self.template_string}, includes={{{self._includes_str()}}})"

    def __call__(self, *args, ref: bool = False, **kwargs) -> TypeClass_T | Ref:
        """Create a type object of this factory's C++ type template.

        Args:
            args, kwargs: Forwarded to the type class constructor, which passes them on
                to the template string formatter (``const`` is accepted as a keyword
                and not passed to the formatter)
            ref: If ``True``, return a reference type

        Returns:
            An instantiated type object, wrapped in `Ref` if ``ref`` is set
        """
        obj = self._type_class(*args, **kwargs)
        if ref:
            return Ref(obj)
        return obj
def cpptype(
    template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
) -> CppTypeFactory:
    """Define a C++ type template together with the header files it requires.

    The template is written in
    `Python format string syntax <https://docs.python.org/3/library/string.html#formatstrings>`_,
    and may be annotated with the header files that have to be included
    wherever the type is used.

    >>> opt_template = lang.cpptype("std::optional< {T} >", "<optional>")
    >>> opt_template.template_string
    'std::optional< {T} >'

    The result is a `CppTypeFactory`; calling it instantiates the template.
    Positional and keyword arguments given to the factory are substituted into
    ``template_str`` by machinery built on `str.format`, yielding the type name.

    >>> int_option = opt_template(T="int")
    >>> int_option.c_string().strip()
    'std::optional< int >'

    Passing ``ref=True`` to the factory yields a reference type instead.

    >>> int_option_ref = opt_template(T="int", ref=True)
    >>> int_option_ref.c_string().strip()
    'std::optional< int >&'

    Args:
        template_str: Format string defining the type template
        include: A single header file, or an iterable of header files

    Returns:
        CppTypeFactory: A factory used to instantiate the type template
    """
    # A lone string or HeaderFile is a single header; anything else is an iterable of headers.
    headers: list[str | HeaderFile] = (
        [include] if isinstance(include, (str, HeaderFile)) else list(include)
    )

    class TypeClass(CppType):
        template_string = template_str
        class_includes = frozenset(map(HeaderFile.parse, headers))

    return CppTypeFactory[TypeClass](TypeClass)
class Ref(PsType):
    """C++ reference type.

    Wraps a base type and renders as ``<base>&``. The ``const`` argument is
    accepted but not forwarded: the reference is always created unqualified.
    In C++ a reference itself cannot be cv-qualified, so qualify the base type instead.
    """

    # Must be a tuple: a plain string makes positional patterns like `case Ref(t)` raise TypeError
    __match_args__ = ("base_type",)

    def __init__(self, base_type: PsType, const: bool = False):
        super().__init__(False)
        self._base_type = base_type

    def __args__(self) -> tuple[Any, ...]:
        """Return the constructor arguments of this type: the referenced base type."""
        return (self.base_type,)

    @property
    def base_type(self) -> PsType:
        """The referenced type."""
        return self._base_type

    def c_string(self) -> str:
        """Render this type as C++ source text (``<base>&``)."""
        base_str = self.base_type.c_string()
        return base_str + "&"

    def __repr__(self) -> str:
        return f"Ref({repr(self.base_type)})"
def strip_ptr_ref(dtype: PsType):
    """Remove all reference and pointer layers from ``dtype``.

    Repeatedly follows ``base_type`` through `Ref` and `PsPointerType` wrappers
    and returns the first type that is neither.
    """
    stripped = dtype
    while isinstance(stripped, (Ref, PsPointerType)):
        stripped = stripped.base_type
    return stripped