From 1793501142cfd38cae17b46278458ed118983b55 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 7 Apr 2024 16:29:48 +0200
Subject: [PATCH] Add loop peeling to `ReshapeLoops`

---
 .../backend/kernelcreation/ast_factory.py     | 16 ++++-
 .../backend/transformations/reshape_loops.py  | 69 +++++++++++++++++--
 2 files changed, 77 insertions(+), 8 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py
index 051dca880..83c406b0a 100644
--- a/src/pystencils/backend/kernelcreation/ast_factory.py
+++ b/src/pystencils/backend/kernelcreation/ast_factory.py
@@ -5,7 +5,7 @@ import sympy as sp
 from sympy.codegen.ast import AssignmentBase
 
 from ..ast import PsAstNode
-from ..ast.expressions import PsExpression, PsSymbolExpr
+from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr
 from ..ast.structural import PsLoop, PsBlock, PsAssignment
 
 from ..symbols import PsSymbol
@@ -56,6 +56,20 @@ class AstFactory:
         """
         return self._typify(self._freeze(sp_obj))
 
+    @overload
+    def parse_index(self, idx: sp.Symbol | PsSymbol | PsSymbolExpr) -> PsSymbolExpr:
+        ...  # typing stub: symbol-like input is annotated to yield a symbol expression
+
+    @overload
+    def parse_index(
+        self, idx: int | np.integer | PsConstant | PsConstantExpr
+    ) -> PsConstantExpr:
+        ...  # typing stub: constant-like input is annotated to yield a constant expression
+
+    @overload
+    def parse_index(self, idx: sp.Expr | PsExpression) -> PsExpression:
+        ...  # typing stub: any other expression is annotated to yield a generic expression
+
     def parse_index(self, idx: IndexParsable):
         """Parse the given object as an expression with data type `ctx.index_dtype`."""
 
diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py
index b194933df..07a37a05d 100644
--- a/src/pystencils/backend/transformations/reshape_loops.py
+++ b/src/pystencils/backend/transformations/reshape_loops.py
@@ -1,12 +1,13 @@
 from typing import Sequence
 
-from ..kernelcreation import KernelCreationContext
+from ..kernelcreation import KernelCreationContext, Typifier
 from ..kernelcreation.ast_factory import AstFactory, IndexParsable
 
-from ..ast.structural import PsLoop, PsBlock, PsConditional
-from ..ast.expressions import PsConstantExpr
+from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration
+from ..ast.expressions import PsExpression, PsConstantExpr, PsLt
+from ..constants import PsConstant
 
-from .canonical_clone import CanonicalClone
+from .canonical_clone import CanonicalClone, CloneContext
 from .eliminate_constants import EliminateConstants
 
 
@@ -15,10 +16,60 @@ class ReshapeLoops:
 
     def __init__(self, ctx: KernelCreationContext) -> None:
         self._ctx = ctx
+        self._typify = Typifier(ctx)
         self._factory = AstFactory(ctx)
         self._canon_clone = CanonicalClone(ctx)
         self._elim_constants = EliminateConstants(ctx)
 
+    def peel_loop_front(
+        self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False
+    ) -> tuple[Sequence[PsBlock], PsLoop]:
+        """Peel off iterations from the front of a loop.
+
+        Removes `num_iterations` from the front of the given loop and returns them as a sequence of
+        independent blocks. The given loop is modified in place: its start is advanced past
+        the peeled iterations, i.e. by `num_iterations` times the loop's step.
+
+        Args:
+            loop: The loop node from which to peel iterations
+            num_iterations: The number of iterations to peel off
+            omit_range_check: If set to `True`, assume that the peeled-off iterations will always
+              be executed, and omit their enclosing conditional.
+
+        Returns:
+            (peeled_iters, loop): Tuple containing the peeled-off iterations as a sequence of blocks,
+            and the remaining loop.
+        """
+        peeled_iters: list[PsBlock] = []
+
+        for i in range(num_iterations):
+            cc = CloneContext(self._ctx)
+            cc.symbol_decl(loop.counter.symbol)
+            peeled_ctr = self._factory.parse_index(
+                cc.get_replacement(loop.counter.symbol)
+            )
+            # The i-th iteration runs at `start + i * step`; a bare `start + i` assumes unit step
+            offset = PsExpression.make(PsConstant(i)) * loop.step
+            counter_decl = PsDeclaration(peeled_ctr, self._typify(loop.start + offset))
+            peeled_block = self._canon_clone.visit(loop.body, cc)
+
+            if omit_range_check:
+                peeled_block.statements = [counter_decl] + peeled_block.statements
+            else:
+                iter_condition = PsLt(peeled_ctr, loop.stop)
+                peeled_block.statements = [
+                    counter_decl,
+                    PsConditional(iter_condition, PsBlock(peeled_block.statements)),
+                ]
+
+            peeled_iters.append(peeled_block)
+
+        loop.start = self._typify(
+            loop.start + PsExpression.make(PsConstant(num_iterations)) * loop.step
+        )
+
+        return peeled_iters, loop
+
     def cut_loop(
         self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
     ) -> Sequence[PsLoop | PsBlock | PsConditional]:
@@ -60,10 +111,14 @@ class ReshapeLoops:
                     skip = True
                 elif num_iters.constant.value == 1:
                     skip = True
-                    cloned_body = self._canon_clone(loop.body)
-                    raise NotImplementedError(
-                        "TODO: Substitute new_start for loop counter"
+                    cc = CloneContext(self._ctx)
+                    cc.symbol_decl(loop.counter.symbol)
+                    ctr_decl = PsDeclaration(
+                        PsExpression.make(cc.get_replacement(loop.counter.symbol)),
+                        new_start,
                     )
+                    cloned_body = self._canon_clone.visit(loop.body, cc)
+                    cloned_body.statements = [ctr_decl] + cloned_body.statements
                     result.append(cloned_body)
 
             if not skip:
-- 
GitLab