diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index dc5a387e8ecb70679b4c125bcb57dd1d7f48e0f6..92bb5c947b4bc2e4b6d50064a9c07874cdda43cf 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -11,7 +11,8 @@ from pystencils.backend.ast.expressions import ( PsEq, PsGt, PsTernary, - PsRem + PsRem, + PsIntDiv ) from pystencils.types.quick import Int, Fp, Bool @@ -28,8 +29,10 @@ f1 = PsExpression.make(PsConstant(1.0, Fp(32))) i0 = PsExpression.make(PsConstant(0, Int(32))) i1 = PsExpression.make(PsConstant(1, Int(32))) +im1 = PsExpression.make(PsConstant(-1, Int(32))) i3 = PsExpression.make(PsConstant(3, Int(32))) +i4 = PsExpression.make(PsConstant(4, Int(32))) im3 = PsExpression.make(PsConstant(-3, Int(32))) i12 = PsExpression.make(PsConstant(12, Int(32))) @@ -105,6 +108,46 @@ def test_divisions(): result = elim(expr) assert result.structurally_equal(i0) + expr = typify(PsIntDiv(i12, i3)) + result = elim(expr) + assert result.structurally_equal(i4) + + expr = typify(i12 / i3) + result = elim(expr) + assert result.structurally_equal(i4) + + expr = typify(PsIntDiv(i4, i3)) + result = elim(expr) + assert result.structurally_equal(i1) + + expr = typify(PsIntDiv(-i4, i3)) + result = elim(expr) + assert result.structurally_equal(im1) + + expr = typify(PsIntDiv(i4, -i3)) + result = elim(expr) + assert result.structurally_equal(im1) + + expr = typify(PsIntDiv(-i4, -i3)) + result = elim(expr) + assert result.structurally_equal(i1) + + expr = typify(PsRem(i4, i3)) + result = elim(expr) + assert result.structurally_equal(i1) + + expr = typify(PsRem(-i4, i3)) + result = elim(expr) + assert result.structurally_equal(im1) + + expr = typify(PsRem(i4, -i3)) + result = elim(expr) + assert result.structurally_equal(i1) + + expr = typify(PsRem(-i4, -i3)) + result = elim(expr) + assert result.structurally_equal(im1) + def test_boolean_folding(): ctx = KernelCreationContext()