# file_printer.py (7.98 KiB)
# Authored by Richard Angersbach
from __future__ import annotations
from textwrap import indent
from pystencils.backend.emission import CAstPrinter
from ..ir import (
SfgSourceFile,
SfgSourceFileType,
SfgNamespaceBlock,
SfgEntityDecl,
SfgEntityDef,
SfgKernelHandle,
SfgFunction,
SfgClassMember,
SfgMethod,
SfgMemberVariable,
SfgConstructor,
SfgClass,
SfgClassBody,
SfgVisibilityBlock,
SfgVisibility,
)
from ..ir.syntax import SfgNamespaceElement, SfgClassBodyElement
from ..config import CodeStyle
class SfgFilePrinter:
def __init__(self, code_style: CodeStyle) -> None:
self._code_style = code_style
self._indent_width = code_style.get_option("indent_width")
def __call__(self, file: SfgSourceFile) -> str:
code = ""
if file.prelude:
comment = "/**\n"
comment += indent(file.prelude, " * ", predicate=lambda _: True)
comment += " */\n\n"
code += comment
if file.file_type == SfgSourceFileType.HEADER:
code += "#pragma once\n\n"
includes = ""
for header in file.includes:
incl = str(header) if header.system_header else f'"{str(header)}"'
includes += f"#include {incl}\n"
if file.file_type == SfgSourceFileType.HYBRID_HEADER:
hybrid_includes = ""
for header in file.hybrid_includes:
incl = str(header) if header.system_header else f'"{str(header)}"'
hybrid_includes += f"#include {incl}\n"
# include different headers and wrap around guard distinguishing C++/C compilations
code += f"""
#ifdef __cplusplus\n
{includes}
#else\n
{hybrid_includes}
#endif\n"""
else:
code += includes
if file.includes:
code += "\n"
# Here begins the actual code
code += "\n\n".join(self.visit(elem) for elem in file.elements)
code += "\n"
return code
def visit(
self, elem: SfgNamespaceElement | SfgClassBodyElement, inclass: bool = False
) -> str:
match elem:
case str():
return elem
case SfgNamespaceBlock(_, elements, label):
code = f"namespace {label} {{\n"
code += self._code_style.indent(
"\n\n".join(self.visit(e) for e in elements)
)
code += f"\n}} // namespace {label}"
return code
case SfgEntityDecl(entity):
return self.visit_decl(entity, inclass)
case SfgEntityDef(entity):
return self.visit_defin(entity, inclass)
case SfgClassBody():
return self.visit_defin(elem, inclass)
case _:
assert False, "illegal code element"
def visit_decl(
self,
declared_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass,
inclass: bool = False,
) -> str:
match declared_entity:
case SfgKernelHandle(kernel):
kernel_printer = CAstPrinter(
indent_width=self._indent_width,
func_prefix="inline" if declared_entity.inline else "",
)
return kernel_printer.print_signature(kernel) + ";"
case SfgFunction(name, _, params) | SfgMethod(name, _, params):
return self._func_signature(declared_entity, inclass) + ";"
case SfgConstructor(cls, params):
params_str = ", ".join(
f"{param.dtype.c_string()} {param.name}" for param in params
)
return f"{cls.name}({params_str});"
case SfgMemberVariable(name, dtype):
return f"{dtype.c_string()} {name};"
case SfgClass(kwd, name):
return f"{str(kwd)} {name};"
case _:
assert False, f"unsupported declared entity: {declared_entity}"
def visit_defin(
self,
defined_entity: SfgKernelHandle | SfgFunction | SfgClassMember | SfgClassBody,
inclass: bool = False,
) -> str:
match defined_entity:
case SfgKernelHandle(kernel):
kernel_printer = CAstPrinter(
indent_width=self._indent_width,
func_prefix="inline" if defined_entity.inline else None,
)
return kernel_printer(kernel)
case SfgFunction(name, tree, params) | SfgMethod(name, tree, params):
sig = self._func_signature(defined_entity, inclass)
body = tree.get_code(self._code_style)
body = "\n{\n" + self._code_style.indent(body) + "\n}"
return sig + body
case SfgConstructor(cls, params):
params_str = ", ".join(
f"{param.dtype.c_string()} {param.name}" for param in params
)
code = ""
if not inclass:
code += f"{cls.name}::"
code += f"{cls.name} ({params_str})"
inits: list[str] = []
for var, args in defined_entity.initializers:
args_str = ", ".join(str(arg) for arg in args)
inits.append(f"{str(var)}({args_str})")
if inits:
code += "\n:" + ",\n".join(inits)
code += "\n{\n" + self._code_style.indent(defined_entity.body) + "\n}"
return code
case SfgMemberVariable(name, dtype):
code = dtype.c_string()
if not inclass:
code += f" {defined_entity.owning_class.name}::"
code += f" {name}"
if defined_entity.default_init is not None:
args_str = ", ".join(
str(expr) for expr in defined_entity.default_init
)
code += "{" + args_str + "}"
code += ";"
return code
case SfgClassBody(cls, vblocks):
code = f"{cls.class_keyword} {cls.name}"
if cls.base_classes:
code += " : " + ", ".join(cls.base_classes)
code += " {\n"
vblocks_str = [self._visibility_block(b) for b in vblocks]
code += "\n\n".join(vblocks_str)
code += "\n};\n"
return code
case _:
assert False, f"unsupported defined entity: {defined_entity}"
def _visibility_block(self, vblock: SfgVisibilityBlock):
prefix = (
f"{vblock.visibility}:\n"
if vblock.visibility != SfgVisibility.DEFAULT
else ""
)
elements = [self.visit(elem, inclass=True) for elem in vblock.elements]
return prefix + self._code_style.indent("\n".join(elements))
def _func_signature(self, func: SfgFunction | SfgMethod, inclass: bool):
code = ""
if func.attributes:
code += "[[" + ", ".join(func.attributes) + "]]"
if func.inline and not inclass:
code += "inline "
if isinstance(func, SfgFunction) and func.externC and not inclass:
code += "EXTERNC "
if isinstance(func, SfgMethod) and inclass:
if func.static:
code += "static "
if func.virtual:
code += "virtual "
if func.constexpr:
code += "constexpr "
code += func.return_type.c_string() + " "
params_str = ", ".join(
f"{param.dtype.c_string()} {param.name}" for param in func.parameters
)
if isinstance(func, SfgMethod) and not inclass:
code += f"{func.owning_class.name}::"
code += f"{func.name}({params_str})"
if isinstance(func, SfgMethod):
if func.const:
code += " const"
if func.override and inclass:
code += " override"
return code