From 575322bc5a9dc5563f0f3b109f2dcd5ad1a55e24 Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Tue, 1 Oct 2024 09:18:46 +0200
Subject: [PATCH] Improve freezing of additions

---
 .../backend/kernelcreation/freeze.py          | 22 ++++------
 tests/nbackend/kernelcreation/test_freeze.py  | 42 +++++++++++++++++--
 2 files changed, 46 insertions(+), 18 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index b626ffb65..fa936e506 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -191,27 +191,19 @@ class FreezeExpressions:
 
     def map_Add(self, expr: sp.Add) -> PsExpression:
         #   TODO: think about numerically sensible ways of freezing sums and products
-        signs: list[int] = []
-        for summand in expr.args:
-            if summand.is_negative:
-                signs.append(-1)
-            elif isinstance(summand, sp.Mul) and any(
-                factor.is_negative for factor in summand.args
-            ):
-                signs.append(-1)
-            else:
-                signs.append(1)
 
         frozen_expr = self.visit_expr(expr.args[0])
 
-        for sign, arg in zip(signs[1:], expr.args[1:]):
-            if sign == -1:
-                arg = -arg
+        for summand in expr.args[1:]:
+            if isinstance(summand, sp.Mul) and any(
+                factor == -1 for factor in summand.args
+            ):
+                summand = -summand
                 op = sub
             else:
                 op = add
 
-            frozen_expr = op(frozen_expr, self.visit_expr(arg))
+            frozen_expr = op(frozen_expr, self.visit_expr(summand))
 
         return frozen_expr
 
@@ -272,7 +264,7 @@ class FreezeExpressions:
 
     def map_TypedSymbol(self, expr: TypedSymbol):
         dtype = expr.dtype
-        
+
         match dtype:
             case DynamicType.NUMERIC_TYPE:
                 dtype = self._ctx.default_dtype
diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py
index 9467bdd8e..2761d9b3f 100644
--- a/tests/nbackend/kernelcreation/test_freeze.py
+++ b/tests/nbackend/kernelcreation/test_freeze.py
@@ -1,7 +1,14 @@
 import sympy as sp
 import pytest
 
-from pystencils import Assignment, fields, create_type, create_numeric_type, TypedSymbol, DynamicType
+from pystencils import (
+    Assignment,
+    fields,
+    create_type,
+    create_numeric_type,
+    TypedSymbol,
+    DynamicType,
+)
 from pystencils.sympyextensions import CastFunc
 
 from pystencils.backend.ast.structural import (
@@ -30,6 +37,9 @@ from pystencils.backend.ast.expressions import (
     PsCall,
     PsCast,
     PsConstantExpr,
+    PsAdd,
+    PsMul,
+    PsSub,
 )
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.functions import PsMathFunction, MathFunctions
@@ -277,11 +287,11 @@ def test_dynamic_types():
     p, q = [TypedSymbol(n, DynamicType.INDEX_TYPE) for n in "pq"]
 
     expr = freeze(x + y)
-    
+
     assert ctx.get_symbol("x").dtype == ctx.default_dtype
     assert ctx.get_symbol("y").dtype == ctx.default_dtype
 
-    expr = freeze(p - q)    
+    expr = freeze(p - q)
     assert ctx.get_symbol("p").dtype == ctx.index_dtype
     assert ctx.get_symbol("q").dtype == ctx.index_dtype
 
@@ -309,3 +319,29 @@ def test_cast_func():
 
     expr = freeze(CastFunc(42, create_type("int16")))
     assert expr.structurally_equal(PsConstantExpr(PsConstant(42, create_type("int16"))))
+
+
+def test_add_sub():
+    ctx = KernelCreationContext()
+    freeze = FreezeExpressions(ctx)
+
+    x = sp.Symbol("x")
+    y = sp.Symbol("y", negative=True)
+
+    x2 = PsExpression.make(ctx.get_symbol("x"))
+    y2 = PsExpression.make(ctx.get_symbol("y"))
+
+    two = PsExpression.make(PsConstant(2))
+    minus_two = PsExpression.make(PsConstant(-2))
+
+    expr = freeze(x + y)
+    assert expr.structurally_equal(PsAdd(x2, y2))
+
+    expr = freeze(x - y)
+    assert expr.structurally_equal(PsSub(x2, y2))
+
+    expr = freeze(x + 2 * y)
+    assert expr.structurally_equal(PsAdd(x2, PsMul(two, y2)))
+
+    expr = freeze(x - 2 * y)
+    assert expr.structurally_equal(PsAdd(x2, PsMul(minus_two, y2)))
-- 
GitLab