From 5f257fe7c90a8252e782a7872bd80b705ea1c7f1 Mon Sep 17 00:00:00 2001 From: Daniel Bauer <daniel.j.bauer@fau.de> Date: Tue, 14 Jan 2025 08:36:57 +0100 Subject: [PATCH] fix canonicalization of >2 loops with the same counter --- .../backend/transformations/canonicalize_symbols.py | 3 ++- .../transformations/test_canonicalize_symbols.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py index c0406c25d..17875caa1 100644 --- a/src/pystencils/backend/transformations/canonicalize_symbols.py +++ b/src/pystencils/backend/transformations/canonicalize_symbols.py @@ -102,10 +102,11 @@ class CanonicalizeSymbols: cc.mark_as_updated(lhs.symbol) case PsLoop(ctr, _, _, _, _): + decl_symb = ctr.symbol for c in node.children[::-1]: self.visit(c, cc) cc.mark_as_updated(ctr.symbol) - cc.end_lifespan(ctr.symbol) + cc.end_lifespan(decl_symb) case PsConditional(cond, then, els): if els is not None: diff --git a/tests/nbackend/transformations/test_canonicalize_symbols.py b/tests/nbackend/transformations/test_canonicalize_symbols.py index a11e9bd13..2758d1234 100644 --- a/tests/nbackend/transformations/test_canonicalize_symbols.py +++ b/tests/nbackend/transformations/test_canonicalize_symbols.py @@ -104,13 +104,16 @@ def test_loop_counters(): loops = factory.loops_from_ispace(ispace, body) - loops_copy = loops.clone() + loops_clone = loops.clone() + loops_clone2 = loops.clone() - ast = PsBlock([loops, loops_copy]) + ast = PsBlock([loops, loops_clone, loops_clone2]) ast = canonicalize(ast) - assert loops_copy.counter.symbol.name == "ctr_0" - assert not loops_copy.counter.symbol.get_dtype().const - assert loops.counter.symbol.name == "ctr_0__0" + assert loops_clone2.counter.symbol.name == "ctr_0" + assert not loops_clone2.counter.symbol.get_dtype().const + assert loops_clone.counter.symbol.name == "ctr_0__0" + assert not loops_clone.counter.symbol.get_dtype().const + assert loops.counter.symbol.name == "ctr_0__1" assert not loops.counter.symbol.get_dtype().const -- GitLab