diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py index 146bc07e19edbe2c74ac833530d28b41b999f530..11c4bfd3bf17c1d9a1d3bbcc8562416aa3870593 100644 --- a/src/pystencils/backend/transformations/reshape_loops.py +++ b/src/pystencils/backend/transformations/reshape_loops.py @@ -48,7 +48,9 @@ class ReshapeLoops: peeled_ctr = self._factory.parse_index( cc.get_replacement(loop.counter.symbol) ) - peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i))) + peeled_idx = self._elim_constants( + self._typify(loop.start + PsExpression.make(PsConstant(i)) * loop.step) + ) counter_decl = PsDeclaration(peeled_ctr, peeled_idx) peeled_block = self._canon_clone.visit(loop.body, cc) @@ -65,7 +67,9 @@ class ReshapeLoops: peeled_iters.append(peeled_block) loop.start = self._elim_constants( - self._typify(loop.start + PsExpression.make(PsConstant(num_iterations))) + self._typify( + loop.start + PsExpression.make(PsConstant(num_iterations)) * loop.step + ) ) return peeled_iters, loop @@ -88,6 +92,13 @@ class ReshapeLoops: Tuple containing the modified loop and the peeled-off iterations (sequence of blocks). """ + if not ( + isinstance(loop.step, PsConstantExpr) and loop.step.constant.value == 1 + ): + raise NotImplementedError( + "Peeling iterations from the back of loops is only implemented for loops with unit step. Implementation is deferred until loop range canonicalization is available (also needed for the vectorizer)." + ) + peeled_iters: list[PsBlock] = [] for i in range(num_iterations)[::-1]: diff --git a/tests/nbackend/transformations/test_reshape_loops.py b/tests/nbackend/transformations/test_reshape_loops.py index f3a3312cf03ea34a844314d15042287551c76fc9..c52e98ba0fb2269afc4c3b968bcfc599cd400a7d 100644 --- a/tests/nbackend/transformations/test_reshape_loops.py +++ b/tests/nbackend/transformations/test_reshape_loops.py @@ -76,7 +76,9 @@ def test_loop_peeling(): x, y, z = sp.symbols("x, y, z") f = Field.create_generic("f", 1, index_shape=(2,)) - ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f) + ispace = FullIterationSpace.create_from_slice( + ctx, slice(2, 11, 3), archetype_field=f + ) ctx.set_iteration_space(ispace) loop_body = PsBlock( @@ -88,9 +90,9 @@ def test_loop_peeling(): loop = factory.loops_from_ispace(ispace, loop_body) - num_iters = 3 + num_iters = 2 peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters) - assert len(peeled_iters) == 3 + assert len(peeled_iters) == num_iters for i, iter in enumerate(peeled_iters): assert isinstance(iter, PsBlock) @@ -98,6 +100,8 @@ def test_loop_peeling(): ctr_decl = iter.statements[0] assert isinstance(ctr_decl, PsDeclaration) assert ctr_decl.declared_symbol.name == f"ctr_0__{i}" + ctr_value = {0: 2, 1: 5}[i] + assert ctr_decl.rhs.structurally_equal(factory.parse_index(ctr_value)) cond = iter.statements[1] assert isinstance(cond, PsConditional) @@ -107,7 +111,7 @@ def test_loop_peeling(): assert isinstance(subblock.statements[0], PsDeclaration) assert subblock.statements[0].declared_symbol.name == f"x__{i}" - assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters)) + assert peeled_loop.start.structurally_equal(factory.parse_index(8)) assert peeled_loop.stop.structurally_equal(loop.stop) assert peeled_loop.body.structurally_equal(loop.body)