Skip to content
Snippets Groups Projects
Commit 63534369 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Add printing support and tests for relations and boolean operators.

 - Add boolean ops and relations to CAstPrinter
 - Add test cases for precedence
 - Refactor precedence table to exactly reflect C++ reference
parent 565498ba
No related branches found
No related tags found
1 merge request!375Support for Boolean Operations and Relations
Pipeline #64914 passed
...@@ -35,6 +35,15 @@ from .ast.expressions import ( ...@@ -35,6 +35,15 @@ from .ast.expressions import (
PsSubscript, PsSubscript,
PsSymbolExpr, PsSymbolExpr,
PsVectorArrayAccess, PsVectorArrayAccess,
PsAnd,
PsOr,
PsNot,
PsEq,
PsNe,
PsGt,
PsLt,
PsGe,
PsLe,
) )
from .symbols import PsSymbol from .symbols import PsSymbol
...@@ -67,32 +76,41 @@ class Ops(Enum): ...@@ -67,32 +76,41 @@ class Ops(Enum):
See also https://en.cppreference.com/w/cpp/language/operator_precedence See also https://en.cppreference.com/w/cpp/language/operator_precedence
""" """
Weakest = (17 - 17, LR.Middle) Call = (2, LR.Left)
Subscript = (2, LR.Left)
Lookup = (2, LR.Left)
BitwiseOr = (17 - 13, LR.Left) Neg = (3, LR.Right)
Not = (3, LR.Right)
AddressOf = (3, LR.Right)
Deref = (3, LR.Right)
Cast = (3, LR.Right)
BitwiseXor = (17 - 12, LR.Left) Mul = (5, LR.Left)
Div = (5, LR.Left)
Rem = (5, LR.Left)
BitwiseAnd = (17 - 11, LR.Left) Add = (6, LR.Left)
Sub = (6, LR.Left)
LeftShift = (17 - 7, LR.Left) LeftShift = (7, LR.Left)
RightShift = (17 - 7, LR.Left) RightShift = (7, LR.Left)
Add = (17 - 6, LR.Left) RelOp = (9, LR.Left)
Sub = (17 - 6, LR.Left)
Mul = (17 - 5, LR.Left) EqOp = (10, LR.Left)
Div = (17 - 5, LR.Left)
Rem = (17 - 5, LR.Left)
Neg = (17 - 3, LR.Right) BitwiseAnd = (11, LR.Left)
AddressOf = (17 - 3, LR.Right)
Deref = (17 - 3, LR.Right)
Cast = (17 - 3, LR.Right)
Call = (17 - 2, LR.Left) BitwiseXor = (12, LR.Left)
Subscript = (17 - 2, LR.Left)
Lookup = (17 - 2, LR.Left) BitwiseOr = (13, LR.Left)
LogicAnd = (14, LR.Left)
LogicOr = (15, LR.Left)
Weakest = (17, LR.Middle)
def __init__(self, pred: int, assoc: LR) -> None: def __init__(self, pred: int, assoc: LR) -> None:
self.precedence = pred self.precedence = pred
...@@ -125,7 +143,7 @@ class PrinterCtx: ...@@ -125,7 +143,7 @@ class PrinterCtx:
return self.branch_stack[-1] return self.branch_stack[-1]
def parenthesize(self, expr: str, next_operator: Ops) -> str: def parenthesize(self, expr: str, next_operator: Ops) -> str:
if next_operator.precedence < self.current_op.precedence: if next_operator.precedence > self.current_op.precedence:
return f"({expr})" return f"({expr})"
elif ( elif (
next_operator.precedence == self.current_op.precedence next_operator.precedence == self.current_op.precedence
...@@ -274,6 +292,13 @@ class CAstPrinter: ...@@ -274,6 +292,13 @@ class CAstPrinter:
return pc.parenthesize(f"-{operand_code}", Ops.Neg) return pc.parenthesize(f"-{operand_code}", Ops.Neg)
case PsNot(operand):
pc.push_op(Ops.Not, LR.Right)
operand_code = self.visit(operand, pc)
pc.pop_op()
return pc.parenthesize(f"!{operand_code}", Ops.Neg)
case PsDeref(operand): case PsDeref(operand):
pc.push_op(Ops.Deref, LR.Right) pc.push_op(Ops.Deref, LR.Right)
operand_code = self.visit(operand, pc) operand_code = self.visit(operand, pc)
...@@ -339,5 +364,21 @@ class CAstPrinter: ...@@ -339,5 +364,21 @@ class CAstPrinter:
return ("^", Ops.BitwiseXor) return ("^", Ops.BitwiseXor)
case PsBitwiseOr(): case PsBitwiseOr():
return ("|", Ops.BitwiseOr) return ("|", Ops.BitwiseOr)
case PsAnd():
return ("&&", Ops.LogicAnd)
case PsOr():
return ("||", Ops.LogicOr)
case PsEq():
return ("==", Ops.EqOp)
case PsNe():
return ("!=", Ops.EqOp)
case PsGt():
return (">", Ops.RelOp)
case PsGe():
return (">=", Ops.RelOp)
case PsLt():
return ("<", Ops.RelOp)
case PsLe():
return ("<=", Ops.RelOp)
case _: case _:
assert False assert False
...@@ -6,7 +6,7 @@ from pystencils.backend.kernelfunction import KernelFunction ...@@ -6,7 +6,7 @@ from pystencils.backend.kernelfunction import KernelFunction
from pystencils.backend.symbols import PsSymbol from pystencils.backend.symbols import PsSymbol
from pystencils.backend.constants import PsConstant from pystencils.backend.constants import PsConstant
from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer
from pystencils.types.quick import Fp, SInt, UInt from pystencils.types.quick import Fp, SInt, UInt, Bool
from pystencils.backend.emission import CAstPrinter from pystencils.backend.emission import CAstPrinter
...@@ -99,3 +99,61 @@ def test_printing_integer_functions(): ...@@ -99,3 +99,61 @@ def test_printing_integer_functions():
) )
code = cprint(expr) code = cprint(expr)
assert code == "(j ^ k ^ (i << (j >> k) | i / k)) & i" assert code == "(j ^ k ^ (i << (j >> k) | i / k)) & i"
def test_logical_precedence():
from pystencils.backend.ast.expressions import PsNot, PsAnd, PsOr
p, q, r = [PsExpression.make(PsSymbol(x, Bool())) for x in "pqr"]
true = PsExpression.make(PsConstant(True, Bool()))
false = PsExpression.make(PsConstant(False, Bool()))
cprint = CAstPrinter()
expr = PsNot(PsAnd(p, PsOr(q, r)))
code = cprint(expr)
assert code == "!(p && (q || r))"
expr = PsAnd(PsAnd(p, q), PsAnd(q, r))
code = cprint(expr)
assert code == "p && q && (q && r)"
expr = PsOr(PsAnd(true, p), PsOr(PsAnd(false, PsNot(q)), PsAnd(r, p)))
code = cprint(expr)
assert code == "true && p || (false && !q || r && p)"
expr = PsAnd(PsOr(PsNot(p), PsNot(q)), PsNot(PsOr(true, false)))
code = cprint(expr)
assert code == "(!p || !q) && !(true || false)"
def test_relations_precedence():
from pystencils.backend.ast.expressions import (
PsNot,
PsAnd,
PsOr,
PsEq,
PsNe,
PsLt,
PsGt,
PsLe,
PsGe,
)
x, y, z = [PsExpression.make(PsSymbol(x, Fp(32))) for x in "xyz"]
cprint = CAstPrinter()
expr = PsAnd(PsEq(x, y), PsLe(y, z))
code = cprint(expr)
assert code == "x == y && y <= z"
expr = PsOr(PsLt(x, y), PsLt(y, z))
code = cprint(expr)
assert code == "x < y || y < z"
expr = PsAnd(PsNot(PsGe(x, y)), PsNot(PsLe(y, z)))
code = cprint(expr)
assert code == "!(x >= y) && !(y <= z)"
expr = PsOr(PsNe(x, y), PsNot(PsGt(y, z)))
code = cprint(expr)
assert code == "x != y || !(y > z)"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment