diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index 7678dbd8c6ce783585fb7095b201e9f92e65e485..7fa4766eb305954f56d10b8cf8052c2fb26cb8fe 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -1,4 +1,4 @@
-from typing import cast, Iterable
+from typing import cast, Iterable, overload
 from collections import defaultdict
 
 from ..kernelcreation import KernelCreationContext, Typifier
@@ -116,6 +116,14 @@ class EliminateConstants:
         self._fold_floats = False
         self._extract_constant_exprs = extract_constant_exprs
 
+    @overload
+    def __call__(self, node: PsExpression) -> PsExpression:
+        pass
+
+    @overload
+    def __call__(self, node: PsAstNode) -> PsAstNode:
+        pass
+
     def __call__(self, node: PsAstNode) -> PsAstNode:
         ecc = ECContext(self._ctx)
 
diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py
index 317586204afe922e5f130b805cb3cbbc10aa62fb..6963bee0b2e43bc6bac58a6c96de5f4a35e57148 100644
--- a/src/pystencils/backend/transformations/reshape_loops.py
+++ b/src/pystencils/backend/transformations/reshape_loops.py
@@ -64,8 +64,8 @@ class ReshapeLoops:
 
             peeled_iters.append(peeled_block)
 
-        loop.start = self._typify(
-            loop.start + PsExpression.make(PsConstant(num_iterations))
+        loop.start = self._elim_constants(
+            self._typify(loop.start + PsExpression.make(PsConstant(num_iterations)))
         )
 
         return peeled_iters, loop
diff --git a/tests/nbackend/transformations/test_reshape_loops.py b/tests/nbackend/transformations/test_reshape_loops.py
index e9c5ff2ee16d484282ccf5a843650fb0f5f3dc6c..e68cff1b64acbb4f9bbf30dee9ef3f2abe9e59d3 100644
--- a/tests/nbackend/transformations/test_reshape_loops.py
+++ b/tests/nbackend/transformations/test_reshape_loops.py
@@ -77,7 +77,8 @@ def test_loop_peeling():
 
     loop = factory.loops_from_ispace(ispace, loop_body)
 
-    peeled_iters, loop = reshape.peel_loop_front(loop, 3)
+    num_iters = 3
+    peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
     assert len(peeled_iters) == 3
 
     for i, iter in enumerate(peeled_iters):
@@ -94,3 +95,7 @@ def test_loop_peeling():
         subblock = cond.branch_true
         assert isinstance(subblock.statements[0], PsDeclaration)
         assert subblock.statements[0].declared_symbol.name == f"x__{i}"
+
+    assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters))
+    assert peeled_loop.stop.structurally_equal(loop.stop)
+    assert peeled_loop.body.structurally_equal(loop.body)