Coverage for src/pystencilssfg/config.py: 97%

157 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from __future__ import annotations 

2 

3from argparse import ArgumentParser, BooleanOptionalAction 

4 

5from types import ModuleType 

6from typing import Any, Sequence, Callable 

7from dataclasses import dataclass 

8from os import path 

9from importlib import util as iutil 

10from pathlib import Path 

11 

12from pystencils.codegen.config import ConfigBase, Option, BasicOption, Category 

13 

14from .lang import HeaderFile 

15 

16 

class SfgConfigException(Exception):
    """Raised when the source file generator is given an invalid or conflicting configuration."""

18 

19 

@dataclass
class FileExtensions(ConfigBase):
    """Option category containing output file extensions."""

    header: BasicOption[str] = BasicOption("hpp")
    """File extension for generated header file."""

    impl: BasicOption[str] = BasicOption("cpp")
    """File extension for generated implementation file."""

    @header.validate
    @impl.validate
    def _validate_extension(self, ext: str | None) -> str | None:
        """Normalize a file extension by removing a single leading dot.

        ``".hpp"`` becomes ``"hpp"``. ``None`` and extensions without a leading dot
        (including the empty string) are returned unchanged.
        """
        # `startswith` instead of `ext[0]`: indexing an empty string raises IndexError
        if ext is not None and ext.startswith("."):
            return ext[1:]

        return ext

37 

38 

@dataclass
class CodeStyle(ConfigBase):
    """Options controlling the code style the source file generator emits."""

    indent_width: BasicOption[int] = BasicOption(2)
    """Number of spaces by which each nested block is indented relative to its parent"""

    includes_sorting_key: BasicOption[Callable[[HeaderFile], Any]] = BasicOption()
    """Key function that will be used to sort ``#include`` statements in generated files.

    Pystencils-sfg will instruct clang-tidy to forego include sorting if this option is set.
    """

    # TODO possible future options:
    #  - newline before opening {
    #  - trailing return types

    def indent(self, s: str) -> str:
        """Prefix the lines of ``s`` with ``indent_width`` spaces.

        Delegates to `textwrap.indent` with its default predicate, so lines that
        consist solely of whitespace receive no prefix.
        """
        import textwrap

        width = self.get_option("indent_width")
        return textwrap.indent(s, width * " ")

61 

62 

@dataclass
class ClangFormatOptions(ConfigBase):
    """Options affecting the invocation of ``clang-format`` for automatic code formatting."""

    code_style: BasicOption[str] = BasicOption("file")
    """Code style to be used by clang-format. Passed verbatim to ``--style`` argument of the clang-format CLI.

    Similar to clang-format itself, the default value is ``file``, such that a ``.clang-format`` file found in the build
    tree will automatically be used.
    """

    force: BasicOption[bool] = BasicOption(False)
    """If ``True``, code generation is aborted when the ``clang-format`` binary cannot be found."""

    skip: BasicOption[bool] = BasicOption(False)
    """If ``True``, formatting with ``clang-format`` is skipped entirely."""

    binary: BasicOption[str] = BasicOption("clang-format")
    """Path to the clang-format executable"""

    @staticmethod
    def _force_skip_conflict() -> SfgConfigException:
        """Build the error reported when ``force`` and ``skip`` are both enabled."""
        return SfgConfigException(
            "Cannot set both `clang_format.force` and `clang_format.skip` at the same time"
        )

    @force.validate
    def _validate_force(self, val: bool) -> bool:
        """Accept ``val`` for ``force`` unless it would be enabled alongside ``skip``."""
        if not val:
            return val
        if self.skip:
            raise self._force_skip_conflict()
        return val

    @skip.validate
    def _validate_skip(self, val: bool) -> bool:
        """Accept ``val`` for ``skip`` unless it would be enabled alongside ``force``."""
        if not val:
            return val
        if self.force:
            raise self._force_skip_conflict()
        return val

98 

99 

class _GlobalNamespace:
    """Marker type; its instance ``GLOBAL_NAMESPACE`` stands for the C++ global namespace."""


GLOBAL_NAMESPACE = _GlobalNamespace()
"""Indicates the C++ global namespace."""

105 

106 

@dataclass
class SfgConfig(ConfigBase):
    """Configuration options for the `SourceFileGenerator`."""

    extensions: Category[FileExtensions] = Category(FileExtensions())
    """File extensions of the generated files

    Options in this category:
        .. autosummary::
            FileExtensions.header
            FileExtensions.impl
    """

    header_only: BasicOption[bool] = BasicOption(False)
    """If set to `True`, generate only a header file.

    This will cause all definitions to be generated ``inline``.
    """

    outer_namespace: BasicOption[str | _GlobalNamespace] = BasicOption(GLOBAL_NAMESPACE)
    """The outermost namespace in the generated file. May be a valid C++ nested namespace qualifier
    (like ``a::b::c``) or `GLOBAL_NAMESPACE` if no outer namespace should be generated.

    .. autosummary::
        GLOBAL_NAMESPACE
    """

    codestyle: Category[CodeStyle] = Category(CodeStyle())
    """Options affecting the code style emitted by pystencils-sfg.

    Options in this category:
        .. autosummary::
            CodeStyle.indent_width
            CodeStyle.includes_sorting_key
    """

    clang_format: Category[ClangFormatOptions] = Category(ClangFormatOptions())
    """Options affecting the invocation of ``clang-format`` for automatic code formatting.

    Options in this category:
        .. autosummary::
            ClangFormatOptions.code_style
            ClangFormatOptions.force
            ClangFormatOptions.skip
            ClangFormatOptions.binary
    """

    output_directory: Option[Path, str | Path] = Option(Path("."))
    """Directory to which the generated files should be written."""

    @output_directory.validate
    def _validate_output_directory(self, pth: str | Path) -> Path:
        """Normalize the configured output directory to a `Path`."""
        return Path(pth)

    def _get_output_files(self, basename: str) -> tuple[Path, ...]:
        """Compute the paths of the files to be generated for ``basename``.

        Returns:
            ``(header,)`` if ``header_only`` is enabled, otherwise
            ``(header, implementation)``; each path lies in ``output_directory``
            and is named ``<basename>.<extension>``.
        """
        output_dir: Path = self.get_option("output_directory")

        header_ext = self.extensions.get_option("header")
        impl_ext = self.extensions.get_option("impl")
        output_files = [output_dir / f"{basename}.{header_ext}"]
        header_only = self.get_option("header_only")

        if not header_only:
            # Narrows the Optional type for type checkers; `impl` has a non-None default.
            assert impl_ext is not None
            output_files.append(output_dir / f"{basename}.{impl_ext}")

        return tuple(output_files)

173 

174 

class CommandLineParameters:
    """Configuration passed to a generator script via ``--sfg-*`` command line arguments.

    Collects the arguments registered by `add_args_to_parser`, optionally imports
    a user-supplied configuration module, and merges both into an `SfgConfig`
    in `get_config`.
    """

    @staticmethod
    def add_args_to_parser(parser: ArgumentParser):
        """Register the ``--sfg-*`` arguments in a ``Configuration`` group of ``parser``.

        Returns:
            The same ``parser`` object.
        """
        config_group = parser.add_argument_group("Configuration")

        config_group.add_argument(
            "--sfg-output-dir",
            type=str,
            default=None,
            dest="output_directory",
            help="Directory to which the generated files should be written.",
        )
        config_group.add_argument(
            "--sfg-file-extensions",
            type=str,
            default=None,
            dest="file_extensions",
            help="Comma-separated list of file extensions",
        )
        config_group.add_argument(
            "--sfg-header-only",
            action=BooleanOptionalAction,
            dest="header_only",
            help="Generate only a header file.",
        )
        config_group.add_argument(
            "--sfg-config-module",
            type=str,
            default=None,
            dest="config_module_path",
            help="Path to a Python configuration module "
            "(may define `configure_sfg` and `project_info`).",
        )

        return parser

    def __init__(self, args) -> None:
        """Read the parsed arguments and import the configuration module, if given.

        Args:
            args: Namespace produced by a parser set up with `add_args_to_parser`.

        Raises:
            SfgConfigException: If the file extensions are invalid, or the
                configuration module cannot be imported.
        """
        self._cl_config_module_path: str | None = args.config_module_path

        self._cl_header_only: bool | None = args.header_only
        self._cl_output_dir: str | None = args.output_directory

        self._cl_header_ext: str | None
        self._cl_impl_ext: str | None
        if args.file_extensions is not None:
            file_extensions = args.file_extensions.split(",")
            h_ext, impl_ext = self._get_file_extensions(file_extensions)
            self._cl_header_ext = h_ext
            self._cl_impl_ext = impl_ext
        else:
            self._cl_header_ext = None
            self._cl_impl_ext = None

        self._config_module: ModuleType | None
        if self._cl_config_module_path is not None:
            self._config_module = self._import_config_module(
                self._cl_config_module_path
            )
        else:
            self._config_module = None

    @property
    def configuration_module(self) -> ModuleType | None:
        """The imported configuration module, or ``None`` if none was given."""
        return self._config_module

    def get_config(self) -> SfgConfig:
        """Build a configuration from the config module and the command line.

        A fresh `SfgConfig` is first passed to the config module's ``configure_sfg``
        function, if the module defines one; afterwards, every option given on
        the command line overrides the corresponding value.
        """
        cfg = SfgConfig()
        if self._config_module is not None and hasattr(
            self._config_module, "configure_sfg"
        ):
            self._config_module.configure_sfg(cfg)

        if self._cl_header_only is not None:
            cfg.header_only = self._cl_header_only
        if self._cl_header_ext is not None:
            cfg.extensions.header = self._cl_header_ext
        if self._cl_impl_ext is not None:
            cfg.extensions.impl = self._cl_impl_ext
        if self._cl_output_dir is not None:
            cfg.output_directory = self._cl_output_dir

        return cfg

    def find_conflicts(self, cfg: SfgConfig):
        """Check that options given on the command line agree with those set in ``cfg``.

        An option conflicts only if it was given on the command line *and* set
        in ``cfg``, and the two values differ.

        Raises:
            SfgConfigException: For the first conflicting option found.
        """
        # The command line keeps the output directory as a string, while the config
        # normalizes it to a `Path`; compare both as paths so that equal
        # directories are not reported as conflicting.
        cl_output_dir = (
            Path(self._cl_output_dir) if self._cl_output_dir is not None else None
        )

        for name, mine, theirs in (
            ("header_only", self._cl_header_only, cfg.header_only),
            ("extensions.header", self._cl_header_ext, cfg.extensions.header),
            ("extensions.impl", self._cl_impl_ext, cfg.extensions.impl),
            ("output_directory", cl_output_dir, cfg.output_directory),
        ):
            if mine is not None and theirs is not None and mine != theirs:
                raise SfgConfigException(
                    f"Conflicting values given for option {name} on command line and inside generator script.\n"
                    f"    Value on command-line: {mine}\n"
                    f"    Value in script: {theirs}"
                )

    def get_project_info(self) -> Any:
        """Return the result of the config module's ``project_info()`` function.

        Returns ``None`` if no config module was given or it defines no ``project_info``.
        """
        if self._config_module is not None and hasattr(
            self._config_module, "project_info"
        ):
            return self._config_module.project_info()
        else:
            return None

    def _get_file_extensions(self, extensions: Sequence[str]):
        """Sort the given file extensions into a header and an implementation extension.

        Surrounding whitespace and one leading dot are removed from each entry;
        entries that are empty after stripping whitespace (e.g. from a trailing
        comma) are ignored.

        Returns:
            Tuple ``(header_ext, impl_ext)``; either element is ``None`` if not given.

        Raises:
            SfgConfigException: If an extension is unknown, or more than one header
                or implementation extension is given.
        """
        h_ext = None
        src_ext = None

        HEADER_FILE_EXTENSIONS = {"h", "hpp", "hxx", "h++", "cuh"}
        IMPL_FILE_EXTENSIONS = {"c", "cpp", "cxx", "c++", "cu", "hip"}

        for raw_ext in extensions:
            ext = raw_ext.strip()
            if not ext:
                # Tolerate empty entries instead of failing with an IndexError
                continue
            if ext.startswith("."):
                ext = ext[1:]

            if ext in HEADER_FILE_EXTENSIONS:
                if h_ext is not None:
                    raise SfgConfigException(
                        "Multiple header file extensions specified."
                    )
                h_ext = ext
            elif ext in IMPL_FILE_EXTENSIONS:
                if src_ext is not None:
                    raise SfgConfigException(
                        "Multiple source file extensions specified."
                    )
                src_ext = ext
            else:
                raise SfgConfigException(
                    f"Invalid file extension: Don't know what to do with '.{ext}'"
                )

        return h_ext, src_ext

    def _import_config_module(self, module_path: str) -> ModuleType:
        """Import and execute the Python file at ``module_path`` as a module.

        The module is named after the file name without its extension.
        Note: this runs arbitrary code from that file, so it must be trusted.

        Raises:
            SfgConfigException: If no import spec or loader can be created for the file.
        """
        cfg_modulename = path.splitext(path.split(module_path)[1])[0]

        cfg_spec = iutil.spec_from_file_location(cfg_modulename, module_path)

        # `spec.loader` may be None; check it here rather than failing with AttributeError
        if cfg_spec is None or cfg_spec.loader is None:
            raise SfgConfigException(
                f"Unable to import configuration module {module_path}",
            )

        config_module = iutil.module_from_spec(cfg_spec)
        cfg_spec.loader.exec_module(config_module)
        return config_module