diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index 051dca880b421dfcb896f938ce43483ff7a11bce..83c406b0a99d52cee9599f321d6c32477f6dbf8a 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -5,7 +5,7 @@ import sympy as sp from sympy.codegen.ast import AssignmentBase from ..ast import PsAstNode -from ..ast.expressions import PsExpression, PsSymbolExpr +from ..ast.expressions import PsExpression, PsSymbolExpr, PsConstantExpr from ..ast.structural import PsLoop, PsBlock, PsAssignment from ..symbols import PsSymbol @@ -56,6 +56,20 @@ class AstFactory: """ return self._typify(self._freeze(sp_obj)) + @overload + def parse_index(self, idx: sp.Symbol | PsSymbol | PsSymbolExpr) -> PsSymbolExpr: + pass + + @overload + def parse_index( + self, idx: int | np.integer | PsConstant | PsConstantExpr + ) -> PsConstantExpr: + pass + + @overload + def parse_index(self, idx: sp.Expr | PsExpression) -> PsExpression: + pass + def parse_index(self, idx: IndexParsable): """Parse the given object as an expression with data type `ctx.index_dtype`.""" diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py index b194933df6126b243d81bfe9f7b5854c9da27c0d..07a37a05dc8fe53c1a59a29eb3694289d129e7a8 100644 --- a/src/pystencils/backend/transformations/reshape_loops.py +++ b/src/pystencils/backend/transformations/reshape_loops.py @@ -1,12 +1,13 @@ from typing import Sequence -from ..kernelcreation import KernelCreationContext +from ..kernelcreation import KernelCreationContext, Typifier from ..kernelcreation.ast_factory import AstFactory, IndexParsable -from ..ast.structural import PsLoop, PsBlock, PsConditional -from ..ast.expressions import PsConstantExpr +from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration +from ..ast.expressions import PsExpression, PsConstantExpr, PsLt +from ..constants import PsConstant -from .canonical_clone import CanonicalClone +from .canonical_clone import CanonicalClone, CloneContext from .eliminate_constants import EliminateConstants @@ -15,10 +16,60 @@ class ReshapeLoops: def __init__(self, ctx: KernelCreationContext) -> None: self._ctx = ctx + self._typify = Typifier(ctx) self._factory = AstFactory(ctx) self._canon_clone = CanonicalClone(ctx) self._elim_constants = EliminateConstants(ctx) + def peel_loop_front( + self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False + ) -> tuple[Sequence[PsBlock], PsLoop]: + """Peel off iterations from the front of a loop. + + Removes `num_iterations` from the front of the given loop and returns them as a sequence of + independent blocks. + + Args: + loop: The loop node from which to peel iterations + num_iterations: The number of iterations to peel off + omit_range_check: If set to `True`, assume that the peeled-off iterations will always + be executed, and omit their enclosing conditional. + + Returns: + (peeled_iters, loop): Tuple containing the peeled-off iterations as a sequence of blocks, + and the remaining loop. + """ + + peeled_iters: list[PsBlock] = [] + + for i in range(num_iterations): + cc = CloneContext(self._ctx) + cc.symbol_decl(loop.counter.symbol) + peeled_ctr = self._factory.parse_index( + cc.get_replacement(loop.counter.symbol) + ) + peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i))) + + counter_decl = PsDeclaration(peeled_ctr, peeled_idx) + peeled_block = self._canon_clone.visit(loop.body, cc) + + if omit_range_check: + peeled_block.statements = [counter_decl] + peeled_block.statements + else: + iter_condition = PsLt(peeled_ctr, loop.stop) + peeled_block.statements = [ + counter_decl, + PsConditional(iter_condition, PsBlock(peeled_block.statements)), + ] + + peeled_iters.append(peeled_block) + + loop.start = self._typify( + loop.start + PsExpression.make(PsConstant(num_iterations)) + ) + + return peeled_iters, loop + def cut_loop( self, loop: PsLoop, cutting_points: Sequence[IndexParsable] ) -> Sequence[PsLoop | PsBlock | PsConditional]: @@ -60,10 +111,14 @@ class ReshapeLoops: skip = True elif num_iters.constant.value == 1: skip = True - cloned_body = self._canon_clone(loop.body) - raise NotImplementedError( - "TODO: Substitute new_start for loop counter" + cc = CloneContext(self._ctx) + cc.symbol_decl(loop.counter.symbol) + ctr_decl = PsDeclaration( + PsExpression.make(cc.get_replacement(loop.counter.symbol)), + new_start, ) + cloned_body = self._canon_clone.visit(loop.body, cc) + cloned_body.statements = [ctr_decl] + cloned_body.statements result.append(cloned_body) if not skip: