diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 7678dbd8c6ce783585fb7095b201e9f92e65e485..7fa4766eb305954f56d10b8cf8052c2fb26cb8fe 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -1,4 +1,4 @@ -from typing import cast, Iterable +from typing import cast, Iterable, overload from collections import defaultdict from ..kernelcreation import KernelCreationContext, Typifier @@ -116,6 +116,14 @@ class EliminateConstants: self._fold_floats = False self._extract_constant_exprs = extract_constant_exprs + @overload + def __call__(self, node: PsExpression) -> PsExpression: + pass + + @overload + def __call__(self, node: PsAstNode) -> PsAstNode: + pass + def __call__(self, node: PsAstNode) -> PsAstNode: ecc = ECContext(self._ctx) diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py index 317586204afe922e5f130b805cb3cbbc10aa62fb..6963bee0b2e43bc6bac58a6c96de5f4a35e57148 100644 --- a/src/pystencils/backend/transformations/reshape_loops.py +++ b/src/pystencils/backend/transformations/reshape_loops.py @@ -64,8 +64,8 @@ class ReshapeLoops: peeled_iters.append(peeled_block) - loop.start = self._typify( - loop.start + PsExpression.make(PsConstant(num_iterations)) + loop.start = self._elim_constants( + self._typify(loop.start + PsExpression.make(PsConstant(num_iterations))) ) return peeled_iters, loop diff --git a/tests/nbackend/transformations/test_reshape_loops.py b/tests/nbackend/transformations/test_reshape_loops.py index e9c5ff2ee16d484282ccf5a843650fb0f5f3dc6c..e68cff1b64acbb4f9bbf30dee9ef3f2abe9e59d3 100644 --- a/tests/nbackend/transformations/test_reshape_loops.py +++ b/tests/nbackend/transformations/test_reshape_loops.py @@ -77,7 +77,8 @@ def test_loop_peeling(): loop = factory.loops_from_ispace(ispace, loop_body) - peeled_iters, loop = reshape.peel_loop_front(loop, 3) + num_iters = 3 + peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters) assert len(peeled_iters) == 3 for i, iter in enumerate(peeled_iters): @@ -94,3 +95,7 @@ def test_loop_peeling(): subblock = cond.branch_true assert isinstance(subblock.statements[0], PsDeclaration) assert subblock.statements[0].declared_symbol.name == f"x__{i}" + + assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters)) + assert peeled_loop.stop.structurally_equal(loop.stop) + assert peeled_loop.body.structurally_equal(loop.body)