Skip to content
Snippets Groups Projects
Commit 87726113 authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon:
Browse files

peel_loop_front: support loops with step != 1, peel_loop_back: raise if that happens

parent 3819d5ae
No related branches found
No related tags found
1 merge request!388Implement loop peeling from back
...@@ -48,7 +48,9 @@ class ReshapeLoops: ...@@ -48,7 +48,9 @@ class ReshapeLoops:
peeled_ctr = self._factory.parse_index( peeled_ctr = self._factory.parse_index(
cc.get_replacement(loop.counter.symbol) cc.get_replacement(loop.counter.symbol)
) )
peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i))) peeled_idx = self._elim_constants(
self._typify(loop.start + PsExpression.make(PsConstant(i)) * loop.step)
)
counter_decl = PsDeclaration(peeled_ctr, peeled_idx) counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
peeled_block = self._canon_clone.visit(loop.body, cc) peeled_block = self._canon_clone.visit(loop.body, cc)
...@@ -65,7 +67,9 @@ class ReshapeLoops: ...@@ -65,7 +67,9 @@ class ReshapeLoops:
peeled_iters.append(peeled_block) peeled_iters.append(peeled_block)
loop.start = self._elim_constants( loop.start = self._elim_constants(
self._typify(loop.start + PsExpression.make(PsConstant(num_iterations))) self._typify(
loop.start + PsExpression.make(PsConstant(num_iterations)) * loop.step
)
) )
return peeled_iters, loop return peeled_iters, loop
...@@ -88,6 +92,13 @@ class ReshapeLoops: ...@@ -88,6 +92,13 @@ class ReshapeLoops:
Tuple containing the modified loop and the peeled-off iterations (sequence of blocks). Tuple containing the modified loop and the peeled-off iterations (sequence of blocks).
""" """
if not (
isinstance(loop.step, PsConstantExpr) and loop.step.constant.value == 1
):
raise NotImplementedError(
"Peeling iterations from the back of loops is only implemented for loops with unit step. Implementation is deferred until loop range canonicalization is available (also needed for the vectorizer)."
)
peeled_iters: list[PsBlock] = [] peeled_iters: list[PsBlock] = []
for i in range(num_iterations)[::-1]: for i in range(num_iterations)[::-1]:
......
...@@ -76,7 +76,9 @@ def test_loop_peeling(): ...@@ -76,7 +76,9 @@ def test_loop_peeling():
x, y, z = sp.symbols("x, y, z") x, y, z = sp.symbols("x, y, z")
f = Field.create_generic("f", 1, index_shape=(2,)) f = Field.create_generic("f", 1, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f) ispace = FullIterationSpace.create_from_slice(
ctx, slice(2, 11, 3), archetype_field=f
)
ctx.set_iteration_space(ispace) ctx.set_iteration_space(ispace)
loop_body = PsBlock( loop_body = PsBlock(
...@@ -88,9 +90,9 @@ def test_loop_peeling(): ...@@ -88,9 +90,9 @@ def test_loop_peeling():
loop = factory.loops_from_ispace(ispace, loop_body) loop = factory.loops_from_ispace(ispace, loop_body)
num_iters = 3 num_iters = 2
peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters) peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
assert len(peeled_iters) == 3 assert len(peeled_iters) == num_iters
for i, iter in enumerate(peeled_iters): for i, iter in enumerate(peeled_iters):
assert isinstance(iter, PsBlock) assert isinstance(iter, PsBlock)
...@@ -98,6 +100,8 @@ def test_loop_peeling(): ...@@ -98,6 +100,8 @@ def test_loop_peeling():
ctr_decl = iter.statements[0] ctr_decl = iter.statements[0]
assert isinstance(ctr_decl, PsDeclaration) assert isinstance(ctr_decl, PsDeclaration)
assert ctr_decl.declared_symbol.name == f"ctr_0__{i}" assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
ctr_value = {0: 2, 1: 5}[i]
assert ctr_decl.rhs.structurally_equal(factory.parse_index(ctr_value))
cond = iter.statements[1] cond = iter.statements[1]
assert isinstance(cond, PsConditional) assert isinstance(cond, PsConditional)
...@@ -107,7 +111,7 @@ def test_loop_peeling(): ...@@ -107,7 +111,7 @@ def test_loop_peeling():
assert isinstance(subblock.statements[0], PsDeclaration) assert isinstance(subblock.statements[0], PsDeclaration)
assert subblock.statements[0].declared_symbol.name == f"x__{i}" assert subblock.statements[0].declared_symbol.name == f"x__{i}"
assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters)) assert peeled_loop.start.structurally_equal(factory.parse_index(8))
assert peeled_loop.stop.structurally_equal(loop.stop) assert peeled_loop.stop.structurally_equal(loop.stop)
assert peeled_loop.body.structurally_equal(loop.body) assert peeled_loop.body.structurally_equal(loop.body)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment