From 63534369461bb5ac7fc2ed8b662a7d7c9243d05d Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 4 Apr 2024 14:19:42 +0200
Subject: [PATCH] Add printing support and tests for relations and boolean
 operators.

 - Add boolean ops and relations to CAstPrinter
 - Add test cases for precedence
 - Refactor precedence table to exactly reflect C++ reference
---
 src/pystencils/backend/emission.py   | 79 +++++++++++++++++++++-------
 tests/nbackend/test_code_printing.py | 60 ++++++++++++++++++++-
 2 files changed, 119 insertions(+), 20 deletions(-)

diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py
index aa5f853a7..135fb3d38 100644
--- a/src/pystencils/backend/emission.py
+++ b/src/pystencils/backend/emission.py
@@ -35,6 +35,15 @@ from .ast.expressions import (
     PsSubscript,
     PsSymbolExpr,
     PsVectorArrayAccess,
+    PsAnd,
+    PsOr,
+    PsNot,
+    PsEq,
+    PsNe,
+    PsGt,
+    PsLt,
+    PsGe,
+    PsLe,
 )
 
 from .symbols import PsSymbol
@@ -67,32 +76,41 @@ class Ops(Enum):
     See also https://en.cppreference.com/w/cpp/language/operator_precedence
     """
 
-    Weakest = (17 - 17, LR.Middle)
+    Call = (2, LR.Left)
+    Subscript = (2, LR.Left)
+    Lookup = (2, LR.Left)
 
-    BitwiseOr = (17 - 13, LR.Left)
+    Neg = (3, LR.Right)
+    Not = (3, LR.Right)
+    AddressOf = (3, LR.Right)
+    Deref = (3, LR.Right)
+    Cast = (3, LR.Right)
 
-    BitwiseXor = (17 - 12, LR.Left)
+    Mul = (5, LR.Left)
+    Div = (5, LR.Left)
+    Rem = (5, LR.Left)
 
-    BitwiseAnd = (17 - 11, LR.Left)
+    Add = (6, LR.Left)
+    Sub = (6, LR.Left)
 
-    LeftShift = (17 - 7, LR.Left)
-    RightShift = (17 - 7, LR.Left)
+    LeftShift = (7, LR.Left)
+    RightShift = (7, LR.Left)
 
-    Add = (17 - 6, LR.Left)
-    Sub = (17 - 6, LR.Left)
+    RelOp = (9, LR.Left)
 
-    Mul = (17 - 5, LR.Left)
-    Div = (17 - 5, LR.Left)
-    Rem = (17 - 5, LR.Left)
+    EqOp = (10, LR.Left)
 
-    Neg = (17 - 3, LR.Right)
-    AddressOf = (17 - 3, LR.Right)
-    Deref = (17 - 3, LR.Right)
-    Cast = (17 - 3, LR.Right)
+    BitwiseAnd = (11, LR.Left)
 
-    Call = (17 - 2, LR.Left)
-    Subscript = (17 - 2, LR.Left)
-    Lookup = (17 - 2, LR.Left)
+    BitwiseXor = (12, LR.Left)
+
+    BitwiseOr = (13, LR.Left)
+
+    LogicAnd = (14, LR.Left)
+
+    LogicOr = (15, LR.Left)
+
+    Weakest = (17, LR.Middle)
 
     def __init__(self, pred: int, assoc: LR) -> None:
         self.precedence = pred
@@ -125,7 +143,7 @@ class PrinterCtx:
         return self.branch_stack[-1]
 
     def parenthesize(self, expr: str, next_operator: Ops) -> str:
-        if next_operator.precedence < self.current_op.precedence:
+        if next_operator.precedence > self.current_op.precedence:
             return f"({expr})"
         elif (
             next_operator.precedence == self.current_op.precedence
@@ -274,6 +292,13 @@ class CAstPrinter:
 
                 return pc.parenthesize(f"-{operand_code}", Ops.Neg)
 
+            case PsNot(operand):
+                pc.push_op(Ops.Not, LR.Right)
+                operand_code = self.visit(operand, pc)
+                pc.pop_op()
+
+                return pc.parenthesize(f"!{operand_code}", Ops.Neg)
+
             case PsDeref(operand):
                 pc.push_op(Ops.Deref, LR.Right)
                 operand_code = self.visit(operand, pc)
@@ -339,5 +364,21 @@ class CAstPrinter:
                 return ("^", Ops.BitwiseXor)
             case PsBitwiseOr():
                 return ("|", Ops.BitwiseOr)
+            case PsAnd():
+                return ("&&", Ops.LogicAnd)
+            case PsOr():
+                return ("||", Ops.LogicOr)
+            case PsEq():
+                return ("==", Ops.EqOp)
+            case PsNe():
+                return ("!=", Ops.EqOp)
+            case PsGt():
+                return (">", Ops.RelOp)
+            case PsGe():
+                return (">=", Ops.RelOp)
+            case PsLt():
+                return ("<", Ops.RelOp)
+            case PsLe():
+                return ("<=", Ops.RelOp)
             case _:
                 assert False
diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py
index c8294c6dd..1fc6821d7 100644
--- a/tests/nbackend/test_code_printing.py
+++ b/tests/nbackend/test_code_printing.py
@@ -6,7 +6,7 @@ from pystencils.backend.kernelfunction import KernelFunction
 from pystencils.backend.symbols import PsSymbol
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer
-from pystencils.types.quick import Fp, SInt, UInt
+from pystencils.types.quick import Fp, SInt, UInt, Bool
 from pystencils.backend.emission import CAstPrinter
 
 
@@ -99,3 +99,61 @@ def test_printing_integer_functions():
     )
     code = cprint(expr)
     assert code == "(j ^ k ^ (i << (j >> k) | i / k)) & i"
+
+
+def test_logical_precedence():
+    from pystencils.backend.ast.expressions import PsNot, PsAnd, PsOr
+
+    p, q, r = [PsExpression.make(PsSymbol(x, Bool())) for x in "pqr"]
+    true = PsExpression.make(PsConstant(True, Bool()))
+    false = PsExpression.make(PsConstant(False, Bool()))
+    cprint = CAstPrinter()
+
+    expr = PsNot(PsAnd(p, PsOr(q, r)))
+    code = cprint(expr)
+    assert code == "!(p && (q || r))"
+
+    expr = PsAnd(PsAnd(p, q), PsAnd(q, r))
+    code = cprint(expr)
+    assert code == "p && q && (q && r)"
+
+    expr = PsOr(PsAnd(true, p), PsOr(PsAnd(false, PsNot(q)), PsAnd(r, p)))
+    code = cprint(expr)
+    assert code == "true && p || (false && !q || r && p)"
+
+    expr = PsAnd(PsOr(PsNot(p), PsNot(q)), PsNot(PsOr(true, false)))
+    code = cprint(expr)
+    assert code == "(!p || !q) && !(true || false)"
+
+
+def test_relations_precedence():
+    from pystencils.backend.ast.expressions import (
+        PsNot,
+        PsAnd,
+        PsOr,
+        PsEq,
+        PsNe,
+        PsLt,
+        PsGt,
+        PsLe,
+        PsGe,
+    )
+
+    x, y, z = [PsExpression.make(PsSymbol(x, Fp(32))) for x in "xyz"]
+    cprint = CAstPrinter()
+
+    expr = PsAnd(PsEq(x, y), PsLe(y, z))
+    code = cprint(expr)
+    assert code == "x == y && y <= z"
+
+    expr = PsOr(PsLt(x, y), PsLt(y, z))
+    code = cprint(expr)
+    assert code == "x < y || y < z"
+
+    expr = PsAnd(PsNot(PsGe(x, y)), PsNot(PsLe(y, z)))
+    code = cprint(expr)
+    assert code == "!(x >= y) && !(y <= z)"
+
+    expr = PsOr(PsNe(x, y), PsNot(PsGt(y, z)))
+    code = cprint(expr)
+    assert code == "x != y || !(y > z)"
-- 
GitLab