Skip to content
Snippets Groups Projects
Select Git revision
  • 63534369461bb5ac7fc2ed8b662a7d7c9243d05d
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

emission.py

Blame
  • Frederik Hennig's avatar
    Frederik Hennig authored
     - Add boolean ops and relations to CAstPrinter
     - Add test cases for precedence
     - Refactor precedence table to exactly reflect C++ reference
    63534369
    History
    emission.py 10.97 KiB
    from __future__ import annotations
    from enum import Enum
    
    from .ast.structural import (
        PsAstNode,
        PsBlock,
        PsStatement,
        PsDeclaration,
        PsAssignment,
        PsLoop,
        PsConditional,
        PsComment,
    )
    
    from .ast.expressions import (
        PsAdd,
        PsAddressOf,
        PsArrayInitList,
        PsBinOp,
        PsBitwiseAnd,
        PsBitwiseOr,
        PsBitwiseXor,
        PsCall,
        PsCast,
        PsConstantExpr,
        PsDeref,
        PsDiv,
        PsIntDiv,
        PsLeftShift,
        PsLookup,
        PsMul,
        PsNeg,
        PsRightShift,
        PsSub,
        PsSubscript,
        PsSymbolExpr,
        PsVectorArrayAccess,
        PsAnd,
        PsOr,
        PsNot,
        PsEq,
        PsNe,
        PsGt,
        PsLt,
        PsGe,
        PsLe,
    )
    
    from .symbols import PsSymbol
    from ..types import PsScalarType, PsArrayType
    
    from .kernelfunction import KernelFunction
    
    
    # Public API of this module: the convenience entry point and the printer class.
    __all__ = ["emit_code", "CAstPrinter"]
    
    
    def emit_code(kernel: KernelFunction):
        """Return the C code for ``kernel`` as rendered by a default-configured `CAstPrinter`."""
        return CAstPrinter()(kernel)
    
    
    class EmissionError(Exception):
        """Raised when the C printer meets an AST construct it cannot emit as code."""

        pass
    
    
    class LR(Enum):
        """Side marker used both as an operator's associativity and as the
        operand branch currently being printed.

        ``Middle`` marks contexts that are neither a left nor a right operand,
        such as the outermost expression, call arguments or subscript indices.
        """

        Left, Right, Middle = range(3)
    
    
    class Ops(Enum):
        """Operator precedence and associativity in C/C++.

        ``precedence`` follows the numbering of
        https://en.cppreference.com/w/cpp/language/operator_precedence,
        where smaller numbers bind tighter. ``assoc`` names the operand branch
        that may hold a sub-expression of equal precedence without parentheses.
        ``Weakest`` (with ``LR.Middle``) is the sentinel for the outermost
        context.

        Members with identical ``(precedence, assoc)`` tuples are, by Python
        ``Enum`` semantics, aliases of the first member with that value
        (e.g. ``Ops.Div is Ops.Mul``). The printer only inspects
        ``precedence`` and ``assoc``, so the aliasing does not affect output.
        """

        # Postfix: function call, subscript, member access
        Call = (2, LR.Left)
        Subscript = (2, LR.Left)
        Lookup = (2, LR.Left)

        # Prefix unary operators and C-style casts
        Neg = (3, LR.Right)
        Not = (3, LR.Right)
        AddressOf = (3, LR.Right)
        Deref = (3, LR.Right)
        Cast = (3, LR.Right)

        # Multiplicative
        Mul = (5, LR.Left)
        Div = (5, LR.Left)
        Rem = (5, LR.Left)

        # Additive
        Add = (6, LR.Left)
        Sub = (6, LR.Left)

        # Bit shifts
        LeftShift = (7, LR.Left)
        RightShift = (7, LR.Left)

        # Relational: <, <=, >, >=
        RelOp = (9, LR.Left)

        # Equality: ==, !=
        EqOp = (10, LR.Left)

        BitwiseAnd = (11, LR.Left)

        BitwiseXor = (12, LR.Left)

        BitwiseOr = (13, LR.Left)

        LogicAnd = (14, LR.Left)

        LogicOr = (15, LR.Left)

        # Sentinel weaker than every operator above (rank of the comma operator)
        Weakest = (17, LR.Middle)

        def __init__(self, precedence: int, associativity: LR) -> None:
            # Enum unpacks each member's tuple value into these positional arguments.
            self.precedence, self.assoc = precedence, associativity
    
    
    class PrinterCtx:
        """Mutable state carried through a single printing pass.

        Holds a stack of enclosing operators and a parallel stack recording
        which operand branch of each operator is currently being printed
        (both used for minimal parenthesization). It also holds the current
        indentation width in spaces.
        """

        def __init__(self) -> None:
            # Seeded with the weakest operator so that an expression printed
            # at the outermost level is not parenthesized.
            self.operator_stack = [Ops.Weakest]
            self.branch_stack = [LR.Middle]
            self.indent_level = 0

        def push_op(self, operator: Ops, branch: LR):
            """Enter ``operator``, starting in its operand ``branch``."""
            self.operator_stack += [operator]
            self.branch_stack += [branch]

        def pop_op(self) -> None:
            """Leave the innermost enclosing operator."""
            self.operator_stack.pop()
            self.branch_stack.pop()

        def switch_branch(self, branch: LR):
            """Move to another operand ``branch`` of the innermost operator."""
            self.branch_stack[-1] = branch

        @property
        def current_op(self) -> Ops:
            """The innermost enclosing operator."""
            return self.operator_stack[-1]

        @property
        def current_branch(self) -> LR:
            """The operand branch of `current_op` that is being printed."""
            return self.branch_stack[-1]

        def parenthesize(self, expr: str, next_operator: Ops) -> str:
            """Wrap ``expr`` in parentheses if it needs them in the current context.

            ``next_operator`` is the top-level operator of ``expr``. Parentheses
            are added when it binds more weakly than the enclosing operator.
            They are also added when it binds equally strongly but sits in the
            branch opposite to the enclosing operator's associativity.
            """
            outer = self.current_op
            if next_operator.precedence == outer.precedence:
                needs_parens = self.current_branch != outer.assoc
            else:
                needs_parens = next_operator.precedence > outer.precedence
            return f"({expr})" if needs_parens else expr

        def indent(self, line: str) -> str:
            """Prefix ``line`` with ``indent_level`` spaces."""
            return "".join((" " * self.indent_level, line))
    
    
    class CAstPrinter:
        def __init__(self, indent_width=3):
            self._indent_width = indent_width
    
        def __call__(self, obj: PsAstNode | KernelFunction) -> str:
            if isinstance(obj, KernelFunction):
                params_str = ", ".join(
                    f"{p.dtype.c_string()} {p.name}" for p in obj.parameters
                )
                decl = f"FUNC_PREFIX void {obj.name} ({params_str})"
                body_code = self.visit(obj.body, PrinterCtx())
                return f"{decl}\n{body_code}"
            else:
                return self.visit(obj, PrinterCtx())
    
        def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
            match node:
                case PsBlock(statements):
                    if not statements:
                        return pc.indent("{ }")
    
                    pc.indent_level += self._indent_width
                    interior = "\n".join(self.visit(stmt, pc) for stmt in statements) + "\n"
                    pc.indent_level -= self._indent_width
                    return pc.indent("{\n") + interior + pc.indent("}\n")
    
                case PsStatement(expr):
                    return pc.indent(f"{self.visit(expr, pc)};")
    
                case PsDeclaration(lhs, rhs):
                    lhs_symb = node.declared_symbol
                    lhs_code = self._symbol_decl(lhs_symb)
                    rhs_code = self.visit(rhs, pc)
    
                    return pc.indent(f"{lhs_code} = {rhs_code};")
    
                case PsAssignment(lhs, rhs):
                    lhs_code = self.visit(lhs, pc)
                    rhs_code = self.visit(rhs, pc)
                    return pc.indent(f"{lhs_code} = {rhs_code};")
    
                case PsLoop(ctr, start, stop, step, body):
                    ctr_symbol = ctr.symbol
    
                    start_code = self.visit(start, pc)
                    stop_code = self.visit(stop, pc)
                    step_code = self.visit(step, pc)
                    body_code = self.visit(body, pc)
    
                    code = (
                        f"for({ctr_symbol.dtype} {ctr_symbol.name} = {start_code};"
                        + f" {ctr.symbol.name} < {stop_code};"
                        + f" {ctr.symbol.name} += {step_code})\n"
                        + body_code
                    )
                    return pc.indent(code)
    
                case PsConditional(condition, branch_true, branch_false):
                    cond_code = self.visit(condition, pc)
                    then_code = self.visit(branch_true, pc)
    
                    code = f"if({cond_code})\n{then_code}"
    
                    if branch_false is not None:
                        else_code = self.visit(branch_false, pc)
                        code += f"\nelse\n{else_code}"
    
                    return pc.indent(code)
    
                case PsComment(lines):
                    lines_list = list(lines)
                    lines_list[0] = "/* " + lines_list[0]
                    for i in range(1, len(lines_list)):
                        lines_list[i] = "   " + lines_list[i]
                    lines_list[-1] = lines_list[-1] + " */"
                    return pc.indent("\n".join(lines_list))
    
                case PsSymbolExpr(symbol):
                    return symbol.name
    
                case PsConstantExpr(constant):
                    dtype = constant.get_dtype()
                    if not isinstance(dtype, PsScalarType):
                        raise EmissionError(
                            "Cannot print literals for non-scalar constants."
                        )
    
                    return dtype.create_literal(constant.value)
    
                case PsVectorArrayAccess():
                    raise EmissionError("Cannot print vectorized array accesses")
    
                case PsSubscript(base, index):
                    pc.push_op(Ops.Subscript, LR.Left)
                    base_code = self.visit(base, pc)
                    pc.pop_op()
    
                    pc.push_op(Ops.Weakest, LR.Middle)
                    index_code = self.visit(index, pc)
                    pc.pop_op()
    
                    return pc.parenthesize(f"{base_code}[{index_code}]", Ops.Subscript)
    
                case PsLookup(aggr, member_name):
                    pc.push_op(Ops.Lookup, LR.Left)
                    aggr_code = self.visit(aggr, pc)
                    pc.pop_op()
    
                    return pc.parenthesize(f"{aggr_code}.{member_name}", Ops.Lookup)
    
                case PsCall(function, args):
                    pc.push_op(Ops.Weakest, LR.Middle)
                    args_string = ", ".join(self.visit(arg, pc) for arg in args)
                    pc.pop_op()
    
                    return pc.parenthesize(f"{function.name}({args_string})", Ops.Call)
    
                case PsBinOp(op1, op2):
                    op_char, op = self._char_and_op(node)
    
                    pc.push_op(op, LR.Left)
                    op1_code = self.visit(op1, pc)
                    pc.switch_branch(LR.Right)
                    op2_code = self.visit(op2, pc)
                    pc.pop_op()
    
                    return pc.parenthesize(f"{op1_code} {op_char} {op2_code}", op)
    
                case PsNeg(operand):
                    pc.push_op(Ops.Neg, LR.Right)
                    operand_code = self.visit(operand, pc)
                    pc.pop_op()
    
                    return pc.parenthesize(f"-{operand_code}", Ops.Neg)
    
                case PsNot(operand):
                    pc.push_op(Ops.Not, LR.Right)
                    operand_code = self.visit(operand, pc)
                    pc.pop_op()
    
                    return pc.parenthesize(f"!{operand_code}", Ops.Neg)
    
                case PsDeref(operand):
                    pc.push_op(Ops.Deref, LR.Right)
                    operand_code = self.visit(operand, pc)
                    pc.pop_op()
    
                    return pc.parenthesize(f"*{operand_code}", Ops.Deref)
    
                case PsAddressOf(operand):
                    pc.push_op(Ops.AddressOf, LR.Right)
                    operand_code = self.visit(operand, pc)
                    pc.pop_op()
    
                    return pc.parenthesize(f"&{operand_code}", Ops.AddressOf)
    
                case PsCast(target_type, operand):
                    pc.push_op(Ops.Cast, LR.Right)
                    operand_code = self.visit(operand, pc)
                    pc.pop_op()
    
                    type_str = target_type.c_string()
                    return pc.parenthesize(f"({type_str}) {operand_code}", Ops.Cast)
    
                case PsArrayInitList(items):
                    pc.push_op(Ops.Weakest, LR.Middle)
                    items_str = ", ".join(self.visit(item, pc) for item in items)
                    pc.pop_op()
                    return "{ " + items_str + " }"
    
                case _:
                    raise NotImplementedError(f"Don't know how to print {node}")
    
        def _symbol_decl(self, symb: PsSymbol):
            dtype = symb.get_dtype()
    
            array_dims = []
            while isinstance(dtype, PsArrayType):
                array_dims.append(dtype.length)
                dtype = dtype.base_type
    
            code = f"{dtype.c_string()} {symb.name}"
            for d in array_dims:
                code += f"[{str(d) if d is not None else ''}]"
    
            return code
    
        def _char_and_op(self, node: PsBinOp) -> tuple[str, Ops]:
            match node:
                case PsAdd():
                    return ("+", Ops.Add)
                case PsSub():
                    return ("-", Ops.Sub)
                case PsMul():
                    return ("*", Ops.Mul)
                case PsDiv() | PsIntDiv():
                    return ("/", Ops.Div)
                case PsLeftShift():
                    return ("<<", Ops.LeftShift)
                case PsRightShift():
                    return (">>", Ops.RightShift)
                case PsBitwiseAnd():
                    return ("&", Ops.BitwiseAnd)
                case PsBitwiseXor():
                    return ("^", Ops.BitwiseXor)
                case PsBitwiseOr():
                    return ("|", Ops.BitwiseOr)
                case PsAnd():
                    return ("&&", Ops.LogicAnd)
                case PsOr():
                    return ("||", Ops.LogicOr)
                case PsEq():
                    return ("==", Ops.EqOp)
                case PsNe():
                    return ("!=", Ops.EqOp)
                case PsGt():
                    return (">", Ops.RelOp)
                case PsGe():
                    return (">=", Ops.RelOp)
                case PsLt():
                    return ("<", Ops.RelOp)
                case PsLe():
                    return ("<=", Ops.RelOp)
                case _:
                    assert False