Skip to content
Snippets Groups Projects

Implement loop peeling from back

Merged Daniel Bauer requested to merge hyteg/pystencils:bauerd/peel-loop-back into backend-rework
All threads resolved!
Compare and
2 files
+ 137
16
Preferences
Compare changes
Files
2
@@ -4,7 +4,7 @@ from ..kernelcreation import KernelCreationContext, Typifier
@@ -4,7 +4,7 @@ from ..kernelcreation import KernelCreationContext, Typifier
from ..kernelcreation.ast_factory import AstFactory, IndexParsable
from ..kernelcreation.ast_factory import AstFactory, IndexParsable
from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration
from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration
from ..ast.expressions import PsExpression, PsConstantExpr, PsLt
from ..ast.expressions import PsExpression, PsConstantExpr, PsGe, PsLt
from ..constants import PsConstant
from ..constants import PsConstant
from .canonical_clone import CanonicalClone, CloneContext
from .canonical_clone import CanonicalClone, CloneContext
@@ -48,7 +48,9 @@ class ReshapeLoops:
@@ -48,7 +48,9 @@ class ReshapeLoops:
peeled_ctr = self._factory.parse_index(
peeled_ctr = self._factory.parse_index(
cc.get_replacement(loop.counter.symbol)
cc.get_replacement(loop.counter.symbol)
)
)
peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i)))
peeled_idx = self._elim_constants(
 
self._typify(loop.start + PsExpression.make(PsConstant(i)) * loop.step)
 
)
counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
peeled_block = self._canon_clone.visit(loop.body, cc)
peeled_block = self._canon_clone.visit(loop.body, cc)
@@ -65,11 +67,68 @@ class ReshapeLoops:
@@ -65,11 +67,68 @@ class ReshapeLoops:
peeled_iters.append(peeled_block)
peeled_iters.append(peeled_block)
loop.start = self._elim_constants(
loop.start = self._elim_constants(
self._typify(loop.start + PsExpression.make(PsConstant(num_iterations)))
self._typify(
 
loop.start + PsExpression.make(PsConstant(num_iterations)) * loop.step
 
)
)
)
return peeled_iters, loop
return peeled_iters, loop
 
def peel_loop_back(
    self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False
) -> tuple[PsLoop, Sequence[PsBlock]]:
    """Split the last ``num_iterations`` iterations off the end of a loop.

    Each peeled iteration becomes its own block: a clone of the loop body
    preceded by a declaration of a fresh counter symbol set to
    ``loop.stop - k``, for ``k`` running from ``num_iterations`` down to ``1``.
    The blocks are therefore returned in ascending iteration order.
    The given loop is modified in place: its ``stop`` is lowered by
    ``num_iterations``.

    Args:
        loop: The loop node to peel iterations from; its ``stop`` is rewritten.
        num_iterations: How many trailing iterations to peel off.
        omit_range_check: If `True`, the cloned body is emitted unguarded,
            i.e. the caller asserts that every peeled iteration is executed.
            Otherwise each cloned body is wrapped in a conditional checking
            that the peeled counter is ``>= loop.start``.

    Returns:
        Tuple of the modified loop and the sequence of peeled-off blocks.

    Raises:
        NotImplementedError: If the loop step is not the constant ``1``.
    """

    step = loop.step
    has_unit_step = isinstance(step, PsConstantExpr) and step.constant.value == 1
    if not has_unit_step:
        raise NotImplementedError(
            "Peeling iterations from the back of loops is only implemented for loops with unit step. Implementation is deferred until loop range canonicalization is available (also needed for the vectorizer)."
        )

    tail_blocks: list[PsBlock] = []

    # `distance` is how far the peeled iteration lies below `loop.stop`;
    # counting it down yields the blocks in ascending iteration order.
    for distance in reversed(range(1, num_iterations + 1)):
        # Every peeled iteration gets its own counter symbol via a fresh clone context.
        clone_ctx = CloneContext(self._ctx)
        clone_ctx.symbol_decl(loop.counter.symbol)
        tail_ctr = self._factory.parse_index(
            clone_ctx.get_replacement(loop.counter.symbol)
        )
        tail_idx = self._typify(loop.stop - PsExpression.make(PsConstant(distance)))

        ctr_decl = PsDeclaration(tail_ctr, tail_idx)
        tail_block = self._canon_clone.visit(loop.body, clone_ctx)

        body_stmts = tail_block.statements
        if not omit_range_check:
            # Guard the iteration so it only runs if it is not before the loop start.
            in_range = PsGe(tail_ctr, loop.start)
            body_stmts = [PsConditional(in_range, PsBlock(body_stmts))]

        tail_block.statements = [ctr_decl, *body_stmts]
        tail_blocks.append(tail_block)

    # Shrink the remaining loop so it no longer covers the peeled iterations.
    shortened_stop = self._typify(
        loop.stop - PsExpression.make(PsConstant(num_iterations))
    )
    loop.stop = self._elim_constants(shortened_stop)

    return loop, tail_blocks
 
def cut_loop(
def cut_loop(
self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
) -> Sequence[PsLoop | PsBlock]:
) -> Sequence[PsLoop | PsBlock]: