Skip to content
Snippets Groups Projects
Commit a4ba27fb authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fix peeled loop start index

parent e9ab9e75
No related branches found
No related tags found
1 merge request!376Loop Transformations: Cutting and Peeling
Pipeline #64984 passed
from typing import cast, Iterable
from typing import cast, Iterable, overload
from collections import defaultdict
from ..kernelcreation import KernelCreationContext, Typifier
......@@ -116,6 +116,14 @@ class EliminateConstants:
self._fold_floats = False
self._extract_constant_exprs = extract_constant_exprs
@overload
def __call__(self, node: PsExpression) -> PsExpression:
pass
@overload
def __call__(self, node: PsAstNode) -> PsAstNode:
pass
def __call__(self, node: PsAstNode) -> PsAstNode:
ecc = ECContext(self._ctx)
......
......@@ -64,8 +64,8 @@ class ReshapeLoops:
peeled_iters.append(peeled_block)
loop.start = self._typify(
loop.start + PsExpression.make(PsConstant(num_iterations))
loop.start = self._elim_constants(
self._typify(loop.start + PsExpression.make(PsConstant(num_iterations)))
)
return peeled_iters, loop
......
......@@ -77,7 +77,8 @@ def test_loop_peeling():
loop = factory.loops_from_ispace(ispace, loop_body)
peeled_iters, loop = reshape.peel_loop_front(loop, 3)
num_iters = 3
peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
assert len(peeled_iters) == 3
for i, iter in enumerate(peeled_iters):
......@@ -94,3 +95,7 @@ def test_loop_peeling():
subblock = cond.branch_true
assert isinstance(subblock.statements[0], PsDeclaration)
assert subblock.statements[0].declared_symbol.name == f"x__{i}"
assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters))
assert peeled_loop.stop.structurally_equal(loop.stop)
assert peeled_loop.body.structurally_equal(loop.body)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment