Skip to content
Snippets Groups Projects
Commit 17935011 authored by Frederik Hennig
Browse files

Add loop peeling to `ReshapeLoops`

parent dd58f40b
No related branches found
No related tags found
1 merge request: !376 "Loop Transformations: Cutting and Peeling"
......@@ -5,7 +5,7 @@ import sympy as sp
from sympy.codegen.ast import AssignmentBase
from ..ast import PsAstNode
from ..ast.expressions import PsExpression, PsSymbolExpr
from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr
from ..ast.structural import PsLoop, PsBlock, PsAssignment
from ..symbols import PsSymbol
......@@ -56,6 +56,20 @@ class AstFactory:
"""
return self._typify(self._freeze(sp_obj))
# Typing-only overloads for `parse_index`: they narrow the declared return
# type according to the kind of object passed as `idx`.

# Symbol-like input is declared to return a symbol expression.
@overload
def parse_index(self, idx: sp.Symbol | PsSymbol | PsSymbolExpr) -> PsSymbolExpr: ...

# Integer and constant-like input is declared to return a constant expression.
@overload
def parse_index(
    self, idx: int | np.integer | PsConstant | PsConstantExpr
) -> PsConstantExpr: ...

# Any other expression is declared to return a generic expression.
@overload
def parse_index(self, idx: sp.Expr | PsExpression) -> PsExpression: ...
def parse_index(self, idx: IndexParsable):
"""Parse the given object as an expression with data type `ctx.index_dtype`."""
......
from typing import Sequence
from ..kernelcreation import KernelCreationContext
from ..kernelcreation import KernelCreationContext, Typifier
from ..kernelcreation.ast_factory import AstFactory, IndexParsable
from ..ast.structural import PsLoop, PsBlock, PsConditional
from ..ast.expressions import PsConstantExpr
from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration
from ..ast.expressions import PsExpression, PsConstantExpr, PsLt
from ..constants import PsConstant
from .canonical_clone import CanonicalClone
from .canonical_clone import CanonicalClone, CloneContext
from .eliminate_constants import EliminateConstants
......@@ -15,10 +16,60 @@ class ReshapeLoops:
def __init__(self, ctx: KernelCreationContext) -> None:
    """Set up the helper objects used by the loop reshaping methods.

    Args:
        ctx: The kernel creation context; it is stored on the instance and
            passed to every helper constructed here.
    """
    self._ctx = ctx

    # Helpers, each constructed from the same context.
    self._typify = Typifier(ctx)
    self._factory = AstFactory(ctx)
    self._canon_clone = CanonicalClone(ctx)
    self._elim_constants = EliminateConstants(ctx)
def peel_loop_front(
    self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False
) -> tuple[Sequence[PsBlock], PsLoop]:
    """Peel off iterations from the front of a loop.

    Removes `num_iterations` from the front of the given loop and returns them
    as a sequence of independent blocks. The given loop is modified in place:
    its `start` is advanced by `num_iterations`.

    Args:
        loop: The loop node from which to peel iterations
        num_iterations: The number of iterations to peel off; must not be negative
        omit_range_check: If set to `True`, assume that the peeled-off iterations
            will always be executed, and omit their enclosing conditional.

    Returns:
        (peeled_iters, loop): Tuple containing the peeled-off iterations as a
        sequence of blocks, and the remaining loop.

    Raises:
        ValueError: If `num_iterations` is negative.
    """
    # A negative count would peel nothing (empty range) but still shift
    # `loop.start` backwards, silently adding iterations to the loop.
    if num_iterations < 0:
        raise ValueError(
            f"num_iterations must be non-negative, got {num_iterations}"
        )

    peeled_iters: list[PsBlock] = []

    for i in range(num_iterations):
        # Each peeled iteration gets its own clone context, and thereby its
        # own replacement for the loop counter symbol.
        cc = CloneContext(self._ctx)
        cc.symbol_decl(loop.counter.symbol)
        peeled_ctr = self._factory.parse_index(
            cc.get_replacement(loop.counter.symbol)
        )

        # NOTE(review): the peeled index is `start + i`; the loop's step is
        # not consulted here, so a unit step appears to be assumed - confirm.
        peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i)))

        counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
        peeled_block = self._canon_clone.visit(loop.body, cc)

        if omit_range_check:
            peeled_block.statements = [counter_decl] + peeled_block.statements
        else:
            # Guard the peeled body so it only runs while the peeled counter
            # is still below the loop's stop value.
            iter_condition = PsLt(peeled_ctr, loop.stop)
            peeled_block.statements = [
                counter_decl,
                PsConditional(iter_condition, PsBlock(peeled_block.statements)),
            ]

        peeled_iters.append(peeled_block)

    # The remaining loop starts after the peeled iterations.
    loop.start = self._typify(
        loop.start + PsExpression.make(PsConstant(num_iterations))
    )

    return peeled_iters, loop
def cut_loop(
self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
) -> Sequence[PsLoop | PsBlock | PsConditional]:
......@@ -60,10 +111,14 @@ class ReshapeLoops:
skip = True
elif num_iters.constant.value == 1:
skip = True
cloned_body = self._canon_clone(loop.body)
raise NotImplementedError(
"TODO: Substitute new_start for loop counter"
cc = CloneContext(self._ctx)
cc.symbol_decl(loop.counter.symbol)
ctr_decl = PsDeclaration(
PsExpression.make(cc.get_replacement(loop.counter.symbol)),
new_start,
)
cloned_body = self._canon_clone.visit(loop.body, cc)
cloned_body.statements = [ctr_decl] + cloned_body.statements
result.append(cloned_body)
if not skip:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment