Coverage for src/pystencilssfg/emission/file_printer.py: 88%

130 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-04 07:16 +0000

1from __future__ import annotations 

2from textwrap import indent 

3 

4from pystencils.backend.emission import CAstPrinter 

5 

6from ..ir import ( 

7 SfgSourceFile, 

8 SfgSourceFileType, 

9 SfgNamespaceBlock, 

10 SfgEntityDecl, 

11 SfgEntityDef, 

12 SfgKernelHandle, 

13 SfgFunction, 

14 SfgClassMember, 

15 SfgMethod, 

16 SfgMemberVariable, 

17 SfgConstructor, 

18 SfgClass, 

19 SfgClassBody, 

20 SfgVisibilityBlock, 

21 SfgVisibility, 

22) 

23from ..ir.syntax import SfgNamespaceElement, SfgClassBodyElement 

24from ..config import CodeStyle 

25 

26 

27class SfgFilePrinter: 

28 def __init__(self, code_style: CodeStyle) -> None: 

29 self._code_style = code_style 

30 self._indent_width = code_style.get_option("indent_width") 

31 

32 def __call__(self, file: SfgSourceFile) -> str: 

33 code = "" 

34 

35 if file.prelude: 

36 comment = "/**\n" 

37 comment += indent(file.prelude, " * ", predicate=lambda _: True) 

38 comment += " */\n\n" 

39 

40 code += comment 

41 

42 if file.file_type == SfgSourceFileType.HEADER: 

43 code += "#pragma once\n\n" 

44 

45 for header in file.includes: 

46 incl = str(header) if header.system_header else f'"{str(header)}"' 

47 code += f"#include {incl}\n" 

48 

49 if file.includes: 

50 code += "\n" 

51 

52 # Here begins the actual code 

53 code += "\n\n".join(self.visit(elem) for elem in file.elements) 

54 code += "\n" 

55 

56 return code 

57 

58 def visit( 

59 self, elem: SfgNamespaceElement | SfgClassBodyElement, inclass: bool = False 

60 ) -> str: 

61 match elem: 

62 case str(): 

63 return elem 

64 case SfgNamespaceBlock(_, elements, label): 

65 code = f"namespace {label} { \n" 

66 code += self._code_style.indent( 

67 "\n\n".join(self.visit(e) for e in elements) 

68 ) 

69 code += f"\n} // namespace {label}" 

70 return code 

71 case SfgEntityDecl(entity): 

72 return self.visit_decl(entity, inclass) 

73 case SfgEntityDef(entity): 

74 return self.visit_defin(entity, inclass) 

75 case SfgClassBody(): 

76 return self.visit_defin(elem, inclass) 

77 case _: 

78 assert False, "illegal code element" 

79 

80 def visit_decl( 

81 self, 

82 declared_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass, 

83 inclass: bool = False, 

84 ) -> str: 

85 match declared_entity: 

86 case SfgKernelHandle(kernel): 

87 kernel_printer = CAstPrinter( 

88 indent_width=self._indent_width, 

89 func_prefix="inline" if declared_entity.inline else None, 

90 ) 

91 return kernel_printer.print_signature(kernel) + ";" 

92 

93 case SfgFunction(name, _, params) | SfgMethod(name, _, params): 

94 return self._func_signature(declared_entity, inclass) + ";" 

95 

96 case SfgConstructor(cls, params): 

97 params_str = ", ".join( 

98 f"{param.dtype.c_string()} {param.name}" for param in params 

99 ) 

100 return f"{cls.name}({params_str});" 

101 

102 case SfgMemberVariable(name, dtype): 

103 return f"{dtype.c_string()} {name};" 

104 

105 case SfgClass(kwd, name): 

106 return f"{str(kwd)} {name};" 

107 

108 case _: 

109 assert False, f"unsupported declared entity: {declared_entity}" 

110 

111 def visit_defin( 

112 self, 

113 defined_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClassBody, 

114 inclass: bool = False, 

115 ) -> str: 

116 match defined_entity: 

117 case SfgKernelHandle(kernel): 

118 kernel_printer = CAstPrinter( 

119 indent_width=self._indent_width, 

120 func_prefix="inline" if defined_entity.inline else None, 

121 ) 

122 return kernel_printer(kernel) 

123 

124 case SfgFunction(name, tree, params) | SfgMethod(name, tree, params): 

125 sig = self._func_signature(defined_entity, inclass) 

126 body = tree.get_code(self._code_style) 

127 body = "\n{\n" + self._code_style.indent(body) + "\n}" 

128 return sig + body 

129 

130 case SfgConstructor(cls, params): 

131 params_str = ", ".join( 

132 f"{param.dtype.c_string()} {param.name}" for param in params 

133 ) 

134 

135 code = "" 

136 if not inclass: 

137 code += f"{cls.name}::" 

138 code += f"{cls.name} ({params_str})" 

139 

140 inits: list[str] = [] 

141 for var, args in defined_entity.initializers: 

142 args_str = ", ".join(str(arg) for arg in args) 

143 inits.append(f"{str(var)}({args_str})") 

144 

145 if inits: 

146 code += "\n:" + ",\n".join(inits) 

147 

148 code += "\n{\n" + self._code_style.indent(defined_entity.body) + "\n}" 

149 return code 

150 

151 case SfgMemberVariable(name, dtype): 

152 code = dtype.c_string() 

153 if not inclass: 

154 code += f" {defined_entity.owning_class.name}::" 

155 code += f" {name}" 

156 if defined_entity.default_init is not None: 

157 args_str = ", ".join( 

158 str(expr) for expr in defined_entity.default_init 

159 ) 

160 code += "{" + args_str + "}" 

161 code += ";" 

162 return code 

163 

164 case SfgClassBody(cls, vblocks): 

165 code = f"{cls.class_keyword} {cls.name}" 

166 if cls.base_classes: 

167 code += " : " + ", ".join(cls.base_classes) 

168 code += " {\n" 

169 vblocks_str = [self._visibility_block(b) for b in vblocks] 

170 code += "\n\n".join(vblocks_str) 

171 code += "\n};\n" 

172 return code 

173 

174 case _: 

175 assert False, f"unsupported defined entity: {defined_entity}" 

176 

177 def _visibility_block(self, vblock: SfgVisibilityBlock): 

178 prefix = ( 

179 f"{vblock.visibility}:\n" 

180 if vblock.visibility != SfgVisibility.DEFAULT 

181 else "" 

182 ) 

183 elements = [self.visit(elem, inclass=True) for elem in vblock.elements] 

184 return prefix + self._code_style.indent("\n".join(elements)) 

185 

186 def _func_signature(self, func: SfgFunction | SfgMethod, inclass: bool): 

187 code = "" 

188 

189 if func.attributes: 

190 code += "[[" + ", ".join(func.attributes) + "]]" 

191 

192 if func.inline and not inclass: 

193 code += "inline " 

194 

195 if isinstance(func, SfgMethod) and inclass: 

196 if func.static: 

197 code += "static " 

198 if func.virtual: 

199 code += "virtual " 

200 

201 if func.constexpr: 

202 code += "constexpr " 

203 

204 code += func.return_type.c_string() + " " 

205 params_str = ", ".join( 

206 f"{param.dtype.c_string()} {param.name}" for param in func.parameters 

207 ) 

208 if isinstance(func, SfgMethod) and not inclass: 

209 code += f"{func.owning_class.name}::" 

210 code += f"{func.name}({params_str})" 

211 

212 if isinstance(func, SfgMethod): 

213 if func.const: 

214 code += " const" 

215 if func.override and inclass: 

216 code += " override" 

217 

218 return code