From 6be4b5af798082acb82ab61623a0e3c2b8250c3b Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 3 Jul 2024 09:26:51 +0200
Subject: [PATCH] Fix multiarg min/max freeze, introduce test case, fix
 MathFunction equality

---
 src/pystencils/backend/ast/expressions.py     |  4 +++
 src/pystencils/backend/functions.py           | 12 +++++++
 .../backend/kernelcreation/freeze.py          |  2 +-
 tests/nbackend/kernelcreation/test_freeze.py  | 34 ++++++++++++++++++-
 4 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index 5e6adfa4f..67fd6b27e 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -419,6 +419,10 @@ class PsCall(PsExpression):
         if not isinstance(other, PsCall):
             return False
         return super().structurally_equal(other) and self._function == other._function
+    
+    def __str__(self):
+        args = ", ".join(str(arg) for arg in self._args)
+        return f"PsCall({self._function}, ({args}))"
 
 
 class PsTernary(PsExpression):
diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py
index 80ea58d18..30b243d9c 100644
--- a/src/pystencils/backend/functions.py
+++ b/src/pystencils/backend/functions.py
@@ -146,3 +146,15 @@ class PsMathFunction(PsFunction):
     @property
     def func(self) -> MathFunctions:
         return self._func
+    
+    def __str__(self) -> str:
+        return f"{self._func.function_name}"
+    
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PsMathFunction):
+            return False
+        
+        return self._func == other._func
+    
+    def __hash__(self) -> int:
+        return hash(self._func)
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 1e1355bbf..3865db38f 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -460,7 +460,7 @@ class FreezeExpressions:
         while len(args) > 1:
             args = [
                 (PsCall(func, (args[i], args[i + 1])) if i + 1 < len(args) else args[i])
-                for i in range((len(args) + 1) // 2)
+                for i in range(0, len(args), 2)
             ]
         return cast(PsCall, args[0])
 
diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py
index b1e8525b1..b22df7d0b 100644
--- a/tests/nbackend/kernelcreation/test_freeze.py
+++ b/tests/nbackend/kernelcreation/test_freeze.py
@@ -25,9 +25,11 @@ from pystencils.backend.ast.expressions import (
     PsLt,
     PsLe,
     PsGt,
-    PsGe
+    PsGe,
+    PsCall
 )
 from pystencils.backend.constants import PsConstant
+from pystencils.backend.functions import PsMathFunction, MathFunctions
 from pystencils.backend.kernelcreation import (
     KernelCreationContext,
     FreezeExpressions,
@@ -227,3 +229,33 @@ def test_freeze_piecewise():
     piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q)))
     with pytest.raises(FreezeError):
         freeze(piecewise)
+
+
+def test_multiarg_min_max():
+    ctx = KernelCreationContext()
+    freeze = FreezeExpressions(ctx)
+
+    w, x, y, z = sp.symbols("w, x, y, z")
+
+    x2 = PsExpression.make(ctx.get_symbol("x"))
+    y2 = PsExpression.make(ctx.get_symbol("y"))
+    z2 = PsExpression.make(ctx.get_symbol("z"))
+    w2 = PsExpression.make(ctx.get_symbol("w"))
+
+    def op(a, b):
+        return PsCall(PsMathFunction(MathFunctions.Min), (a, b))
+
+    expr = freeze(sp.Min(w, x, y))
+    assert expr.structurally_equal(op(op(w2, x2), y2))
+
+    expr = freeze(sp.Min(w, x, y, z))
+    assert expr.structurally_equal(op(op(w2, x2), op(y2, z2)))
+
+    def op(a, b):
+        return PsCall(PsMathFunction(MathFunctions.Max), (a, b))
+
+    expr = freeze(sp.Max(w, x, y))
+    assert expr.structurally_equal(op(op(w2, x2), y2))
+
+    expr = freeze(sp.Max(w, x, y, z))
+    assert expr.structurally_equal(op(op(w2, x2), op(y2, z2)))
-- 
GitLab