From a333aa55eeb85b83e1fe37f0fdc2d0773500b29c Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 30 Jun 2024 13:41:33 +0200
Subject: [PATCH] integer division and remainder folding tests

---
 .../test_constant_elimination.py              | 45 ++++++++++++++++++-
 1 file changed, 44 insertions(+), 1 deletion(-)

diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py
index dc5a387e8..92bb5c947 100644
--- a/tests/nbackend/transformations/test_constant_elimination.py
+++ b/tests/nbackend/transformations/test_constant_elimination.py
@@ -11,7 +11,8 @@ from pystencils.backend.ast.expressions import (
     PsEq,
     PsGt,
     PsTernary,
-    PsRem
+    PsRem,
+    PsIntDiv
 )
 
 from pystencils.types.quick import Int, Fp, Bool
@@ -28,8 +29,10 @@ f1 = PsExpression.make(PsConstant(1.0, Fp(32)))
 
 i0 = PsExpression.make(PsConstant(0, Int(32)))
 i1 = PsExpression.make(PsConstant(1, Int(32)))
+im1 = PsExpression.make(PsConstant(-1, Int(32)))
 
 i3 = PsExpression.make(PsConstant(3, Int(32)))
+i4 = PsExpression.make(PsConstant(4, Int(32)))
 im3 = PsExpression.make(PsConstant(-3, Int(32)))
 i12 = PsExpression.make(PsConstant(12, Int(32)))
 
@@ -105,6 +108,46 @@ def test_divisions():
     result = elim(expr)
     assert result.structurally_equal(i0)
 
+    expr = typify(PsIntDiv(i12, i3))
+    result = elim(expr)
+    assert result.structurally_equal(i4)
+
+    expr = typify(i12 / i3)
+    result = elim(expr)
+    assert result.structurally_equal(i4)
+
+    expr = typify(PsIntDiv(i4, i3))
+    result = elim(expr)
+    assert result.structurally_equal(i1)
+
+    expr = typify(PsIntDiv(-i4, i3))
+    result = elim(expr)
+    assert result.structurally_equal(im1)
+
+    expr = typify(PsIntDiv(i4, -i3))
+    result = elim(expr)
+    assert result.structurally_equal(im1)
+
+    expr = typify(PsIntDiv(-i4, -i3))
+    result = elim(expr)
+    assert result.structurally_equal(i1)
+
+    expr = typify(PsRem(i4, i3))
+    result = elim(expr)
+    assert result.structurally_equal(i1)
+
+    expr = typify(PsRem(-i4, i3))
+    result = elim(expr)
+    assert result.structurally_equal(im1)
+
+    expr = typify(PsRem(i4, -i3))
+    result = elim(expr)
+    assert result.structurally_equal(i1)
+
+    expr = typify(PsRem(-i4, -i3))
+    result = elim(expr)
+    assert result.structurally_equal(im1)
+
 
 def test_boolean_folding():
     ctx = KernelCreationContext()
-- 
GitLab