Coverage for src/pystencilssfg/emission/file_printer.py: 88%
130 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-04 07:16 +0000
1from __future__ import annotations
2from textwrap import indent
4from pystencils.backend.emission import CAstPrinter
6from ..ir import (
7 SfgSourceFile,
8 SfgSourceFileType,
9 SfgNamespaceBlock,
10 SfgEntityDecl,
11 SfgEntityDef,
12 SfgKernelHandle,
13 SfgFunction,
14 SfgClassMember,
15 SfgMethod,
16 SfgMemberVariable,
17 SfgConstructor,
18 SfgClass,
19 SfgClassBody,
20 SfgVisibilityBlock,
21 SfgVisibility,
22)
23from ..ir.syntax import SfgNamespaceElement, SfgClassBodyElement
24from ..config import CodeStyle
27class SfgFilePrinter:
28 def __init__(self, code_style: CodeStyle) -> None:
29 self._code_style = code_style
30 self._indent_width = code_style.get_option("indent_width")
32 def __call__(self, file: SfgSourceFile) -> str:
33 code = ""
35 if file.prelude:
36 comment = "/**\n"
37 comment += indent(file.prelude, " * ", predicate=lambda _: True)
38 comment += " */\n\n"
40 code += comment
42 if file.file_type == SfgSourceFileType.HEADER:
43 code += "#pragma once\n\n"
45 for header in file.includes:
46 incl = str(header) if header.system_header else f'"{str(header)}"'
47 code += f"#include {incl}\n"
49 if file.includes:
50 code += "\n"
52 # Here begins the actual code
53 code += "\n\n".join(self.visit(elem) for elem in file.elements)
54 code += "\n"
56 return code
58 def visit(
59 self, elem: SfgNamespaceElement | SfgClassBodyElement, inclass: bool = False
60 ) -> str:
61 match elem:
62 case str():
63 return elem
64 case SfgNamespaceBlock(_, elements, label):
65 code = f"namespace {label} { \n"
66 code += self._code_style.indent(
67 "\n\n".join(self.visit(e) for e in elements)
68 )
69 code += f"\n} // namespace {label}"
70 return code
71 case SfgEntityDecl(entity):
72 return self.visit_decl(entity, inclass)
73 case SfgEntityDef(entity):
74 return self.visit_defin(entity, inclass)
75 case SfgClassBody():
76 return self.visit_defin(elem, inclass)
77 case _:
78 assert False, "illegal code element"
80 def visit_decl(
81 self,
82 declared_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass,
83 inclass: bool = False,
84 ) -> str:
85 match declared_entity:
86 case SfgKernelHandle(kernel):
87 kernel_printer = CAstPrinter(
88 indent_width=self._indent_width,
89 func_prefix="inline" if declared_entity.inline else None,
90 )
91 return kernel_printer.print_signature(kernel) + ";"
93 case SfgFunction(name, _, params) | SfgMethod(name, _, params):
94 return self._func_signature(declared_entity, inclass) + ";"
96 case SfgConstructor(cls, params):
97 params_str = ", ".join(
98 f"{param.dtype.c_string()} {param.name}" for param in params
99 )
100 return f"{cls.name}({params_str});"
102 case SfgMemberVariable(name, dtype):
103 return f"{dtype.c_string()} {name};"
105 case SfgClass(kwd, name):
106 return f"{str(kwd)} {name};"
108 case _:
109 assert False, f"unsupported declared entity: {declared_entity}"
111 def visit_defin(
112 self,
113 defined_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClassBody,
114 inclass: bool = False,
115 ) -> str:
116 match defined_entity:
117 case SfgKernelHandle(kernel):
118 kernel_printer = CAstPrinter(
119 indent_width=self._indent_width,
120 func_prefix="inline" if defined_entity.inline else None,
121 )
122 return kernel_printer(kernel)
124 case SfgFunction(name, tree, params) | SfgMethod(name, tree, params):
125 sig = self._func_signature(defined_entity, inclass)
126 body = tree.get_code(self._code_style)
127 body = "\n{\n" + self._code_style.indent(body) + "\n}"
128 return sig + body
130 case SfgConstructor(cls, params):
131 params_str = ", ".join(
132 f"{param.dtype.c_string()} {param.name}" for param in params
133 )
135 code = ""
136 if not inclass:
137 code += f"{cls.name}::"
138 code += f"{cls.name} ({params_str})"
140 inits: list[str] = []
141 for var, args in defined_entity.initializers:
142 args_str = ", ".join(str(arg) for arg in args)
143 inits.append(f"{str(var)}({args_str})")
145 if inits:
146 code += "\n:" + ",\n".join(inits)
148 code += "\n{\n" + self._code_style.indent(defined_entity.body) + "\n}"
149 return code
151 case SfgMemberVariable(name, dtype):
152 code = dtype.c_string()
153 if not inclass:
154 code += f" {defined_entity.owning_class.name}::"
155 code += f" {name}"
156 if defined_entity.default_init is not None:
157 args_str = ", ".join(
158 str(expr) for expr in defined_entity.default_init
159 )
160 code += "{" + args_str + "}"
161 code += ";"
162 return code
164 case SfgClassBody(cls, vblocks):
165 code = f"{cls.class_keyword} {cls.name}"
166 if cls.base_classes:
167 code += " : " + ", ".join(cls.base_classes)
168 code += " {\n"
169 vblocks_str = [self._visibility_block(b) for b in vblocks]
170 code += "\n\n".join(vblocks_str)
171 code += "\n};\n"
172 return code
174 case _:
175 assert False, f"unsupported defined entity: {defined_entity}"
177 def _visibility_block(self, vblock: SfgVisibilityBlock):
178 prefix = (
179 f"{vblock.visibility}:\n"
180 if vblock.visibility != SfgVisibility.DEFAULT
181 else ""
182 )
183 elements = [self.visit(elem, inclass=True) for elem in vblock.elements]
184 return prefix + self._code_style.indent("\n".join(elements))
186 def _func_signature(self, func: SfgFunction | SfgMethod, inclass: bool):
187 code = ""
189 if func.attributes:
190 code += "[[" + ", ".join(func.attributes) + "]]"
192 if func.inline and not inclass:
193 code += "inline "
195 if isinstance(func, SfgMethod) and inclass:
196 if func.static:
197 code += "static "
198 if func.virtual:
199 code += "virtual "
201 if func.constexpr:
202 code += "constexpr "
204 code += func.return_type.c_string() + " "
205 params_str = ", ".join(
206 f"{param.dtype.c_string()} {param.name}" for param in func.parameters
207 )
208 if isinstance(func, SfgMethod) and not inclass:
209 code += f"{func.owning_class.name}::"
210 code += f"{func.name}({params_str})"
212 if isinstance(func, SfgMethod):
213 if func.const:
214 code += " const"
215 if func.override and inclass:
216 code += " override"
218 return code