Skip to content
Snippets Groups Projects
Commit 3819d5ae authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon:
Browse files

implement loop peeling from back

parent 90239d2c
No related branches found
No related tags found
1 merge request: !388 "Implement loop peeling from back"
......@@ -4,7 +4,7 @@ from ..kernelcreation import KernelCreationContext, Typifier
from ..kernelcreation.ast_factory import AstFactory, IndexParsable
from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration
from ..ast.expressions import PsExpression, PsConstantExpr, PsLt
from ..ast.expressions import PsExpression, PsConstantExpr, PsGe, PsLt
from ..constants import PsConstant
from .canonical_clone import CanonicalClone, CloneContext
......@@ -70,6 +70,54 @@ class ReshapeLoops:
return peeled_iters, loop
def peel_loop_back(
    self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False
) -> tuple[PsLoop, Sequence[PsBlock]]:
    """Peel off iterations from the back of a loop.

    Removes ``num_iterations`` from the back of the given loop and returns them as a
    sequence of independent blocks.

    The loop is modified in place: its ``stop`` is replaced by ``stop - num_iterations``
    (after typification and constant elimination), and the same loop object is returned.
    Each peeled block is a canonical clone of the loop body, preceded by a declaration of
    a replacement counter symbol set to ``stop - k``. The blocks are produced for
    ``k = num_iterations, ..., 1``, i.e. in ascending counter order.

    NOTE(review): the peeled counters are computed as ``stop - k``, which presumably
    assumes a unit-step loop -- confirm before using this on loops with another step.

    Args:
        loop: The loop node from which to peel iterations
        num_iterations: The number of iterations to peel off; must not be negative
        omit_range_check: If set to `True`, assume that the peeled-off iterations will always
          be executed, and omit their enclosing conditional. Otherwise each peeled body is
          wrapped in a conditional checking ``counter >= loop.start``.

    Returns:
        Tuple containing the modified loop and the peeled-off iterations (sequence of blocks).

    Raises:
        ValueError: If ``num_iterations`` is negative.
    """
    # A negative count would peel nothing but still move `loop.stop` upwards,
    # silently enlarging the iteration range. Reject it up front.
    if num_iterations < 0:
        raise ValueError(
            f"Cannot peel a negative number of iterations: {num_iterations}"
        )

    peeled_iters: list[PsBlock] = []

    # Count the offset down so that the blocks come out in ascending counter order:
    # stop - num_iterations, ..., stop - 1.
    for i in reversed(range(num_iterations)):
        # Each peeled iteration gets its own clone context, and thus its own
        # replacement for the loop counter symbol.
        cc = CloneContext(self._ctx)
        cc.symbol_decl(loop.counter.symbol)
        peeled_ctr = self._factory.parse_index(
            cc.get_replacement(loop.counter.symbol)
        )
        peeled_idx = self._typify(loop.stop - PsExpression.make(PsConstant(i + 1)))

        counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
        peeled_block = self._canon_clone.visit(loop.body, cc)

        if omit_range_check:
            peeled_block.statements = [counter_decl] + peeled_block.statements
        else:
            # Guard the body: the peeled index may fall below the loop start
            # if the loop has fewer iterations than are being peeled.
            iter_condition = PsGe(peeled_ctr, loop.start)
            peeled_block.statements = [
                counter_decl,
                PsConditional(iter_condition, PsBlock(peeled_block.statements)),
            ]

        peeled_iters.append(peeled_block)

    # Shrink the remaining loop only after all peeled indices were derived
    # from the original upper bound.
    loop.stop = self._elim_constants(
        self._typify(loop.stop - PsExpression.make(PsConstant(num_iterations)))
    )

    return loop, peeled_iters
def cut_loop(
self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
) -> Sequence[PsLoop | PsBlock]:
......
......@@ -8,8 +8,13 @@ from pystencils.backend.kernelcreation import (
)
from pystencils.backend.transformations import ReshapeLoops
from pystencils.backend.ast.structural import PsDeclaration, PsBlock, PsLoop, PsConditional
from pystencils.backend.ast.expressions import PsConstantExpr, PsLt
from pystencils.backend.ast.structural import (
PsDeclaration,
PsBlock,
PsLoop,
PsConditional,
)
from pystencils.backend.ast.expressions import PsConstantExpr, PsGe, PsLt
def test_loop_cutting():
......@@ -43,10 +48,12 @@ def test_loop_cutting():
x_decl = subloop.statements[1]
assert isinstance(x_decl, PsDeclaration)
assert x_decl.declared_symbol.name == "x__0"
subloop = subloops[1]
assert isinstance(subloop, PsLoop)
assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
assert (
isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
)
assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3
x_decl = subloop.body.statements[0]
......@@ -55,7 +62,9 @@ def test_loop_cutting():
subloop = subloops[2]
assert isinstance(subloop, PsLoop)
assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
assert (
isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
)
assert subloop.stop.structurally_equal(loop.stop)
......@@ -70,10 +79,12 @@ def test_loop_peeling():
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
ctx.set_iteration_space(ispace)
loop_body = PsBlock([
factory.parse_sympy(Assignment(x, 2 * z)),
factory.parse_sympy(Assignment(f.center(0), x + y)),
])
loop_body = PsBlock(
[
factory.parse_sympy(Assignment(x, 2 * z)),
factory.parse_sympy(Assignment(f.center(0), x + y)),
]
)
loop = factory.loops_from_ispace(ispace, loop_body)
......@@ -99,3 +110,50 @@ def test_loop_peeling():
assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters))
assert peeled_loop.stop.structurally_equal(loop.stop)
assert peeled_loop.body.structurally_equal(loop.body)
def test_loop_peeling_back():
    """Peel iterations off the back of a 1D loop and check blocks and remaining loop."""
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)
    reshape = ReshapeLoops(ctx)

    x, y, z = sp.symbols("x, y, z")

    f = Field.create_generic("f", 1, index_shape=(2,))
    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
    ctx.set_iteration_space(ispace)

    # Loop body: a temporary `x` followed by a field write that uses it.
    temp_assignment = factory.parse_sympy(Assignment(x, 2 * z))
    field_assignment = factory.parse_sympy(Assignment(f.center(0), x + y))
    loop_body = PsBlock([temp_assignment, field_assignment])

    loop = factory.loops_from_ispace(ispace, loop_body)

    num_iters = 3
    peeled_loop, peeled_iters = reshape.peel_loop_back(loop, num_iters)
    assert len(peeled_iters) == num_iters

    for idx, peeled in enumerate(peeled_iters):
        assert isinstance(peeled, PsBlock)

        # First statement: declaration of the renamed loop counter.
        ctr_decl = peeled.statements[0]
        assert isinstance(ctr_decl, PsDeclaration)
        assert ctr_decl.declared_symbol.name == f"ctr_0__{idx}"

        # Second statement: the range check wrapping the cloned body.
        guard = peeled.statements[1]
        assert isinstance(guard, PsConditional)
        expected_condition = PsGe(ctr_decl.lhs, loop.start)
        assert guard.condition.structurally_equal(expected_condition)

        # The cloned body declares a renamed copy of `x`.
        guarded_body = guard.branch_true
        first_stmt = guarded_body.statements[0]
        assert isinstance(first_stmt, PsDeclaration)
        assert first_stmt.declared_symbol.name == f"x__{idx}"

    assert peeled_loop.start.structurally_equal(loop.start)

    # Rebuild the loop from the iteration space to recover the unmodified upper bound.
    reference_loop = factory.loops_from_ispace(ispace, loop_body)
    expected_stop = reference_loop.stop - factory.parse_index(num_iters)
    assert peeled_loop.stop.structurally_equal(expected_stop)

    assert peeled_loop.body.structurally_equal(loop.body)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment