From a4ba27fb9cd8720945083cb92ec8314d47f3e6da Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 8 Apr 2024 11:05:22 +0200 Subject: [PATCH] Fix peeled loop start index --- .../backend/transformations/eliminate_constants.py | 10 +++++++++- .../backend/transformations/reshape_loops.py | 4 ++-- tests/nbackend/transformations/test_reshape_loops.py | 7 ++++++- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 7678dbd8c..7fa4766eb 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -1,4 +1,4 @@ -from typing import cast, Iterable +from typing import cast, Iterable, overload from collections import defaultdict from ..kernelcreation import KernelCreationContext, Typifier @@ -116,6 +116,14 @@ class EliminateConstants: self._fold_floats = False self._extract_constant_exprs = extract_constant_exprs + @overload + def __call__(self, node: PsExpression) -> PsExpression: + pass + + @overload + def __call__(self, node: PsAstNode) -> PsAstNode: + pass + def __call__(self, node: PsAstNode) -> PsAstNode: ecc = ECContext(self._ctx) diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py index 317586204..6963bee0b 100644 --- a/src/pystencils/backend/transformations/reshape_loops.py +++ b/src/pystencils/backend/transformations/reshape_loops.py @@ -64,8 +64,8 @@ class ReshapeLoops: peeled_iters.append(peeled_block) - loop.start = self._typify( - loop.start + PsExpression.make(PsConstant(num_iterations)) + loop.start = self._elim_constants( + self._typify(loop.start + PsExpression.make(PsConstant(num_iterations))) ) return peeled_iters, loop diff --git a/tests/nbackend/transformations/test_reshape_loops.py b/tests/nbackend/transformations/test_reshape_loops.py index e9c5ff2ee..e68cff1b6 100644 --- a/tests/nbackend/transformations/test_reshape_loops.py +++ b/tests/nbackend/transformations/test_reshape_loops.py @@ -77,7 +77,8 @@ def test_loop_peeling(): loop = factory.loops_from_ispace(ispace, loop_body) - peeled_iters, loop = reshape.peel_loop_front(loop, 3) + num_iters = 3 + peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters) assert len(peeled_iters) == 3 for i, iter in enumerate(peeled_iters): @@ -94,3 +95,7 @@ def test_loop_peeling(): subblock = cond.branch_true assert isinstance(subblock.statements[0], PsDeclaration) assert subblock.statements[0].declared_symbol.name == f"x__{i}" + + assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters)) + assert peeled_loop.stop.structurally_equal(loop.stop) + assert peeled_loop.body.structurally_equal(loop.body) -- GitLab