Coverage for src/pystencilssfg/lang/types.py: 94%

116 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from __future__ import annotations 

2from typing import Any, Iterable, Sequence, Mapping, TypeVar, Generic 

3from abc import ABC 

4from dataclasses import dataclass 

5from itertools import chain 

6 

7import string 

8 

9from pystencils.types import PsType, PsPointerType, PsCustomType 

10from .headers import HeaderFile 

11 

12 

class VoidType(PsType):
    """The C++ ``void`` type.

    The ``const`` parameter is accepted but not forwarded to the base class,
    so every instance is constructed as non-const.
    """

    def __init__(self, const: bool = False):
        # ``const`` is intentionally unused here; see the class docstring.
        super().__init__(False)

    def __args__(self) -> tuple[Any, ...]:
        # void carries no arguments that distinguish one instance from another.
        return tuple()

    def c_string(self) -> str:
        return "void"

    def __repr__(self) -> str:
        return "VoidType()"


# Shared module-level instance of the void type.
void = VoidType()

30 

31 

32class _TemplateArgFormatter(string.Formatter): 

33 

34 def format_field(self, arg, format_spec): 

35 if isinstance(arg, PsType): 

36 arg = arg.c_string() 

37 return super().format_field(arg, format_spec) 

38 

39 def check_unused_args( 

40 self, used_args: set[int | str], args: Sequence, kwargs: Mapping[str, Any] 

41 ) -> None: 

42 max_args_len: int = ( 

43 max((k for k in used_args if isinstance(k, int)), default=-1) + 1 

44 ) 

45 if len(args) > max_args_len: 

46 raise ValueError( 

47 f"Too many positional arguments: Expected {max_args_len}, but got {len(args)}" 

48 ) 

49 

50 extra_keys = set(kwargs.keys()) - used_args # type: ignore 

51 if extra_keys: 

52 raise ValueError(f"Extraneous keyword arguments: {extra_keys}") 

53 

54 

55@dataclass(frozen=True) 

56class _TemplateArgs: 

57 pargs: tuple[Any, ...] 

58 kwargs: tuple[tuple[str, Any], ...] 

59 

60 

class CppType(PsCustomType, ABC):
    """Base class for C++ types described by a format-string template.

    Concrete subclasses (normally generated by `cpptype`) provide:

    - ``template_string``: a format string that, rendered with the template
      arguments, yields the C++ type name;
    - ``class_includes``: the headers the type itself requires.

    Headers needed by `CppType` / `PsType` template arguments are merged into
    `includes`.
    """

    class_includes: frozenset[HeaderFile]
    template_string: str

    def __init__(self, *template_args, const: bool = False, **template_kwargs):
        first = template_args[0] if template_args else None
        if isinstance(first, _TemplateArgs):
            # Cloning path: arguments were packed by another instance's __args__.
            assert not template_kwargs
            targs = first
            pargs = targs.pargs
            kwargs = dict(targs.kwargs)
        else:
            pargs = template_args
            kwargs = template_kwargs
            # Sorting by name makes the stored tuple independent of the order
            # in which keyword arguments were passed.
            sorted_kwargs = tuple(sorted(kwargs.items(), key=lambda item: item[0]))
            targs = _TemplateArgs(pargs, sorted_kwargs)

        name = _TemplateArgFormatter().format(self.template_string, *pargs, **kwargs)

        self._targs = targs
        self._includes = self.class_includes

        # Collect the headers needed by any type-valued template argument.
        for arg in chain(pargs, kwargs.values()):
            if isinstance(arg, CppType):
                self._includes |= arg.includes
            elif isinstance(arg, PsType):
                self._includes |= {HeaderFile.parse(h) for h in arg.required_headers}

        super().__init__(name, const=const)

    def __args__(self) -> tuple[Any, ...]:
        # A single packed argument; __init__ recognises it for cloning.
        return (self._targs,)

    @property
    def includes(self) -> frozenset[HeaderFile]:
        """Headers of this type together with those of its template arguments."""
        return self._includes

    @property
    def required_headers(self) -> set[str]:
        """Header names declared by the type class itself.

        NOTE(review): unlike `includes`, this omits headers contributed by
        template arguments — confirm that this is intended.
        """
        return {str(header) for header in self.class_includes}

106 

107 

TypeClass_T = TypeVar("TypeClass_T", bound=CppType)
"""Python type variable bound to `CppType`."""


class CppTypeFactory(Generic[TypeClass_T]):
    """Type Factory returned by `cpptype`.

    Wraps a `CppType` subclass; calling the factory instantiates that class
    (optionally wrapped in `Ref`), see `__call__`.
    """

    def __init__(self, tclass: type[TypeClass_T]) -> None:
        self._type_class = tclass

    @property
    def includes(self) -> frozenset[HeaderFile]:
        """Set of headers required by this factory's type"""
        return self._type_class.class_includes

    @property
    def template_string(self) -> str:
        """Template string of this factory's type"""
        return self._type_class.template_string

    def __str__(self) -> str:
        # The template string is enclosed in a balanced pair of backticks.
        return f"Factory for `{self.template_string}` defined in {self.includes}"

    def __repr__(self) -> str:
        # ``{{`` / ``}}`` emit literal braces around the header list; the names
        # are sorted so the output does not depend on frozenset iteration order.
        headers = ", ".join(sorted(str(i) for i in self.includes))
        return f"CppTypeFactory({self.template_string}, includes={{{headers}}})"

    def __call__(self, *args, ref: bool = False, **kwargs) -> TypeClass_T | Ref:
        """Create a type object of this factory's C++ type template.

        Args:
            args, kwargs: Positional and keyword arguments are forwarded to the template string formatter
            ref: If ``True``, return a reference type

        Returns:
            An instantiated type object
        """

        obj = self._type_class(*args, **kwargs)
        if ref:
            return Ref(obj)
        else:
            return obj

150 

151 

def cpptype(
    template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
) -> CppTypeFactory:
    """Describe a C++ type template together with the header files it requires.

    User-defined C++ type templates are written in
    `Python format string syntax <https://docs.python.org/3/library/string.html#formatstrings>`_,
    and may be annotated with header files that must be included to use the type.

    >>> opt_template = lang.cpptype("std::optional< {T} >", "<optional>")
    >>> opt_template.template_string
    'std::optional< {T} >'

    The result is a `CppTypeFactory`; calling it instantiates the template.
    The factory's positional and keyword arguments are fed, through
    `str.format`-based machinery, into ``template_str`` to build the type name.

    >>> int_option = opt_template(T="int")
    >>> int_option.c_string().strip()
    'std::optional< int >'

    Passing ``ref=True`` to the factory yields a reference type instead.

    >>> int_option_ref = opt_template(T="int", ref=True)
    >>> int_option_ref.c_string().strip()
    'std::optional< int >&'

    Args:
        template_str: Format string defining the type template
        include: A single header (name or `HeaderFile`), or an iterable of them

    Returns:
        CppTypeFactory: A factory used to instantiate the type template
    """
    # A lone header is wrapped in a list; any other iterable is materialised.
    headers: list[str | HeaderFile] = (
        [include] if isinstance(include, (str, HeaderFile)) else list(include)
    )

    class TypeClass(CppType):
        template_string = template_str
        class_includes = frozenset(HeaderFile.parse(h) for h in headers)

    return CppTypeFactory[TypeClass](TypeClass)

203 

204 

class Ref(PsType):
    """C++ reference type.

    Wraps a base type and renders as the base type's C string followed by ``&``.
    """

    # Must be a tuple of attribute names: positional class patterns such as
    # ``case Ref(base):`` raise TypeError if ``__match_args__`` is a str.
    __match_args__ = ("base_type",)

    def __init__(self, base_type: PsType, const: bool = False):
        # ``const`` is accepted but not forwarded; the reference itself is always
        # constructed non-const. NOTE(review): presumably because C++ references
        # cannot be cv-qualified — confirm.
        super().__init__(False)
        self._base_type = base_type

    def __args__(self) -> tuple[Any, ...]:
        return (self.base_type,)

    @property
    def base_type(self) -> PsType:
        """The referenced type."""
        return self._base_type

    def c_string(self) -> str:
        base_str = self.base_type.c_string()
        return base_str + "&"

    def __repr__(self) -> str:
        return f"Ref({repr(self.base_type)})"

227 

228 

def strip_ptr_ref(dtype: PsType):
    """Peel off every reference and pointer layer of ``dtype``.

    Repeatedly replaces a `Ref` or `PsPointerType` by its ``base_type`` and
    returns the first type that is neither.
    """
    stripped = dtype
    while isinstance(stripped, (Ref, PsPointerType)):
        stripped = stripped.base_type
    return stripped