From a4ba27fb9cd8720945083cb92ec8314d47f3e6da Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 8 Apr 2024 11:05:22 +0200
Subject: [PATCH] Fix peeled loop start index

---
 .../backend/transformations/eliminate_constants.py     | 10 +++++++++-
 .../backend/transformations/reshape_loops.py           |  4 ++--
 tests/nbackend/transformations/test_reshape_loops.py   |  7 ++++++-
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index 7678dbd8c..7fa4766eb 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -1,4 +1,4 @@
-from typing import cast, Iterable
+from typing import cast, Iterable, overload
 from collections import defaultdict
 
 from ..kernelcreation import KernelCreationContext, Typifier
@@ -116,6 +116,14 @@ class EliminateConstants:
         self._fold_floats = False
         self._extract_constant_exprs = extract_constant_exprs
 
+    @overload
+    def __call__(self, node: PsExpression) -> PsExpression:
+        pass
+
+    @overload
+    def __call__(self, node: PsAstNode) -> PsAstNode:
+        pass
+
     def __call__(self, node: PsAstNode) -> PsAstNode:
         ecc = ECContext(self._ctx)
 
diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py
index 317586204..6963bee0b 100644
--- a/src/pystencils/backend/transformations/reshape_loops.py
+++ b/src/pystencils/backend/transformations/reshape_loops.py
@@ -64,8 +64,8 @@ class ReshapeLoops:
 
             peeled_iters.append(peeled_block)
 
-        loop.start = self._typify(
-            loop.start + PsExpression.make(PsConstant(num_iterations))
+        loop.start = self._elim_constants(
+            self._typify(loop.start + PsExpression.make(PsConstant(num_iterations)))
         )
 
         return peeled_iters, loop
diff --git a/tests/nbackend/transformations/test_reshape_loops.py b/tests/nbackend/transformations/test_reshape_loops.py
index e9c5ff2ee..e68cff1b6 100644
--- a/tests/nbackend/transformations/test_reshape_loops.py
+++ b/tests/nbackend/transformations/test_reshape_loops.py
@@ -77,7 +77,8 @@ def test_loop_peeling():
 
     loop = factory.loops_from_ispace(ispace, loop_body)
 
-    peeled_iters, loop = reshape.peel_loop_front(loop, 3)
+    num_iters = 3
+    peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
     assert len(peeled_iters) == 3
 
     for i, iter in enumerate(peeled_iters):
@@ -94,3 +95,7 @@ def test_loop_peeling():
         subblock = cond.branch_true
         assert isinstance(subblock.statements[0], PsDeclaration)
         assert subblock.statements[0].declared_symbol.name == f"x__{i}"
+
+    assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters))
+    assert peeled_loop.stop.structurally_equal(loop.stop)
+    assert peeled_loop.body.structurally_equal(loop.body)
-- 
GitLab