From 63534369461bb5ac7fc2ed8b662a7d7c9243d05d Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 4 Apr 2024 14:19:42 +0200 Subject: [PATCH] Add printing support and tests for relations and boolean operators. - Add boolean ops and relations to CAstPrinter - Add test cases for precedence - Refactor precedence table to exactly reflect C++ reference --- src/pystencils/backend/emission.py | 79 +++++++++++++++++++++------- tests/nbackend/test_code_printing.py | 60 ++++++++++++++++++++- 2 files changed, 119 insertions(+), 20 deletions(-) diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index aa5f853a7..135fb3d38 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -35,6 +35,15 @@ from .ast.expressions import ( PsSubscript, PsSymbolExpr, PsVectorArrayAccess, + PsAnd, + PsOr, + PsNot, + PsEq, + PsNe, + PsGt, + PsLt, + PsGe, + PsLe, ) from .symbols import PsSymbol @@ -67,32 +76,41 @@ class Ops(Enum): See also https://en.cppreference.com/w/cpp/language/operator_precedence """ - Weakest = (17 - 17, LR.Middle) + Call = (2, LR.Left) + Subscript = (2, LR.Left) + Lookup = (2, LR.Left) - BitwiseOr = (17 - 13, LR.Left) + Neg = (3, LR.Right) + Not = (3, LR.Right) + AddressOf = (3, LR.Right) + Deref = (3, LR.Right) + Cast = (3, LR.Right) - BitwiseXor = (17 - 12, LR.Left) + Mul = (5, LR.Left) + Div = (5, LR.Left) + Rem = (5, LR.Left) - BitwiseAnd = (17 - 11, LR.Left) + Add = (6, LR.Left) + Sub = (6, LR.Left) - LeftShift = (17 - 7, LR.Left) - RightShift = (17 - 7, LR.Left) + LeftShift = (7, LR.Left) + RightShift = (7, LR.Left) - Add = (17 - 6, LR.Left) - Sub = (17 - 6, LR.Left) + RelOp = (9, LR.Left) - Mul = (17 - 5, LR.Left) - Div = (17 - 5, LR.Left) - Rem = (17 - 5, LR.Left) + EqOp = (10, LR.Left) - Neg = (17 - 3, LR.Right) - AddressOf = (17 - 3, LR.Right) - Deref = (17 - 3, LR.Right) - Cast = (17 - 3, LR.Right) + BitwiseAnd = (11, LR.Left) - Call = (17 - 2, LR.Left) - Subscript = (17 - 2, LR.Left) - Lookup = (17 - 2, LR.Left) + BitwiseXor = (12, LR.Left) + + BitwiseOr = (13, LR.Left) + + LogicAnd = (14, LR.Left) + + LogicOr = (15, LR.Left) + + Weakest = (17, LR.Middle) def __init__(self, pred: int, assoc: LR) -> None: self.precedence = pred @@ -125,7 +143,7 @@ class PrinterCtx: return self.branch_stack[-1] def parenthesize(self, expr: str, next_operator: Ops) -> str: - if next_operator.precedence < self.current_op.precedence: + if next_operator.precedence > self.current_op.precedence: return f"({expr})" elif ( next_operator.precedence == self.current_op.precedence @@ -274,6 +292,13 @@ class CAstPrinter: return pc.parenthesize(f"-{operand_code}", Ops.Neg) + case PsNot(operand): + pc.push_op(Ops.Not, LR.Right) + operand_code = self.visit(operand, pc) + pc.pop_op() + + return pc.parenthesize(f"!{operand_code}", Ops.Neg) + case PsDeref(operand): pc.push_op(Ops.Deref, LR.Right) operand_code = self.visit(operand, pc) @@ -339,5 +364,21 @@ class CAstPrinter: return ("^", Ops.BitwiseXor) case PsBitwiseOr(): return ("|", Ops.BitwiseOr) + case PsAnd(): + return ("&&", Ops.LogicAnd) + case PsOr(): + return ("||", Ops.LogicOr) + case PsEq(): + return ("==", Ops.EqOp) + case PsNe(): + return ("!=", Ops.EqOp) + case PsGt(): + return (">", Ops.RelOp) + case PsGe(): + return (">=", Ops.RelOp) + case PsLt(): + return ("<", Ops.RelOp) + case PsLe(): + return ("<=", Ops.RelOp) case _: assert False diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index c8294c6dd..1fc6821d7 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -6,7 +6,7 @@ from pystencils.backend.kernelfunction import KernelFunction from pystencils.backend.symbols import PsSymbol from pystencils.backend.constants import PsConstant from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer -from pystencils.types.quick import Fp, SInt, UInt +from pystencils.types.quick import Fp, SInt, UInt, Bool from pystencils.backend.emission import CAstPrinter @@ -99,3 +99,61 @@ def test_printing_integer_functions(): ) code = cprint(expr) assert code == "(j ^ k ^ (i << (j >> k) | i / k)) & i" + + +def test_logical_precedence(): + from pystencils.backend.ast.expressions import PsNot, PsAnd, PsOr + + p, q, r = [PsExpression.make(PsSymbol(x, Bool())) for x in "pqr"] + true = PsExpression.make(PsConstant(True, Bool())) + false = PsExpression.make(PsConstant(False, Bool())) + cprint = CAstPrinter() + + expr = PsNot(PsAnd(p, PsOr(q, r))) + code = cprint(expr) + assert code == "!(p && (q || r))" + + expr = PsAnd(PsAnd(p, q), PsAnd(q, r)) + code = cprint(expr) + assert code == "p && q && (q && r)" + + expr = PsOr(PsAnd(true, p), PsOr(PsAnd(false, PsNot(q)), PsAnd(r, p))) + code = cprint(expr) + assert code == "true && p || (false && !q || r && p)" + + expr = PsAnd(PsOr(PsNot(p), PsNot(q)), PsNot(PsOr(true, false))) + code = cprint(expr) + assert code == "(!p || !q) && !(true || false)" + + +def test_relations_precedence(): + from pystencils.backend.ast.expressions import ( + PsNot, + PsAnd, + PsOr, + PsEq, + PsNe, + PsLt, + PsGt, + PsLe, + PsGe, + ) + + x, y, z = [PsExpression.make(PsSymbol(x, Fp(32))) for x in "xyz"] + cprint = CAstPrinter() + + expr = PsAnd(PsEq(x, y), PsLe(y, z)) + code = cprint(expr) + assert code == "x == y && y <= z" + + expr = PsOr(PsLt(x, y), PsLt(y, z)) + code = cprint(expr) + assert code == "x < y || y < z" + + expr = PsAnd(PsNot(PsGe(x, y)), PsNot(PsLe(y, z))) + code = cprint(expr) + assert code == "!(x >= y) && !(y <= z)" + + expr = PsOr(PsNe(x, y), PsNot(PsGt(y, z))) + code = cprint(expr) + assert code == "x != y || !(y > z)" -- GitLab