From 5f257fe7c90a8252e782a7872bd80b705ea1c7f1 Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Tue, 14 Jan 2025 08:36:57 +0100
Subject: [PATCH] fix canonicalization of >2 loops with the same counter

---
 .../backend/transformations/canonicalize_symbols.py |  3 ++-
 .../transformations/test_canonicalize_symbols.py    | 13 ++++++++-----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py
index c0406c25d..17875caa1 100644
--- a/src/pystencils/backend/transformations/canonicalize_symbols.py
+++ b/src/pystencils/backend/transformations/canonicalize_symbols.py
@@ -102,10 +102,11 @@ class CanonicalizeSymbols:
                     cc.mark_as_updated(lhs.symbol)
 
             case PsLoop(ctr, _, _, _, _):
+                decl_symb = ctr.symbol
                 for c in node.children[::-1]:
                     self.visit(c, cc)
                 cc.mark_as_updated(ctr.symbol)
-                cc.end_lifespan(ctr.symbol)
+                cc.end_lifespan(decl_symb)
 
             case PsConditional(cond, then, els):
                 if els is not None:
diff --git a/tests/nbackend/transformations/test_canonicalize_symbols.py b/tests/nbackend/transformations/test_canonicalize_symbols.py
index a11e9bd13..2758d1234 100644
--- a/tests/nbackend/transformations/test_canonicalize_symbols.py
+++ b/tests/nbackend/transformations/test_canonicalize_symbols.py
@@ -104,13 +104,16 @@ def test_loop_counters():
 
     loops = factory.loops_from_ispace(ispace, body)
 
-    loops_copy = loops.clone()
+    loops_clone = loops.clone()
+    loops_clone2 = loops.clone()
 
-    ast = PsBlock([loops, loops_copy])
+    ast = PsBlock([loops, loops_clone, loops_clone2])
 
     ast = canonicalize(ast)
 
-    assert loops_copy.counter.symbol.name == "ctr_0"
-    assert not loops_copy.counter.symbol.get_dtype().const
-    assert loops.counter.symbol.name == "ctr_0__0"
+    assert loops_clone2.counter.symbol.name == "ctr_0"
+    assert not loops_clone2.counter.symbol.get_dtype().const
+    assert loops_clone.counter.symbol.name == "ctr_0__0"
+    assert not loops_clone.counter.symbol.get_dtype().const
+    assert loops.counter.symbol.name == "ctr_0__1"
     assert not loops.counter.symbol.get_dtype().const
-- 
GitLab