Summary of this patch
* iteration_space.py: FullIterationSpace.create_from_slice additionally accepts
  a single `slice`; it is wrapped into a one-element tuple before processing.
* reshape_loops.py: ReshapeLoops.cut_loop
    * return annotation narrowed to Sequence[PsLoop | PsBlock];
    * docstring states that cutting points are not checked against the loop range;
    * the iteration count is passed through self._typify before self._elim_constants;
    * the counter of a single-iteration piece is built with self._factory.parse_index
      instead of PsExpression.make.
* tests/nbackend/transformations/test_reshape_loops.py: new file with
  test_loop_cutting and test_loop_peeling.

NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Blank lines were placed so that every hunk matches
its header counts; leading indentation was reconstructed from Python syntax.
If a context line fails to match, apply with `git apply --ignore-whitespace`.
NOTE(review): in test_loop_peeling the loop variable `iter` was renamed to
`peeled_iter` so it no longer shadows the builtin; the post-image blob id on
that file's `index` line presumably predates this rename.

diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py
index ba215f822ea7372211bf764425d44e44487cc46b..6adac2a519ffc04505c0e0adac3484d78f30d013 100644
--- a/src/pystencils/backend/kernelcreation/iteration_space.py
+++ b/src/pystencils/backend/kernelcreation/iteration_space.py
@@ -121,7 +121,7 @@ class FullIterationSpace(IterationSpace):
     @staticmethod
     def create_from_slice(
         ctx: KernelCreationContext,
-        iteration_slice: Sequence[slice],
+        iteration_slice: slice | Sequence[slice],
         archetype_field: Field | None = None,
     ):
         """Create an iteration space from a sequence of slices, optionally over an archetype field.
@@ -131,6 +131,9 @@ class FullIterationSpace(IterationSpace):
             iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice`
             archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order.
         """
+        if isinstance(iteration_slice, slice):
+            iteration_slice = (iteration_slice,)
+
         dim = len(iteration_slice)
         if dim == 0:
             raise ValueError(
diff --git a/src/pystencils/backend/transformations/reshape_loops.py b/src/pystencils/backend/transformations/reshape_loops.py
index ea04cdff55136c067536b66a60d0c6485fde47c2..317586204afe922e5f130b805cb3cbbc10aa62fb 100644
--- a/src/pystencils/backend/transformations/reshape_loops.py
+++ b/src/pystencils/backend/transformations/reshape_loops.py
@@ -72,7 +72,7 @@ class ReshapeLoops:
 
     def cut_loop(
         self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
-    ) -> Sequence[PsLoop | PsBlock | PsConditional]:
+    ) -> Sequence[PsLoop | PsBlock]:
         """Cut a loop at the given cutting points.
 
         Cut the given loop at the iterations specified by the given cutting points,
@@ -82,6 +82,9 @@ class ReshapeLoops:
         Resulting subtrees representing zero iterations are dropped; subtrees representing exactly one iteration are
         returned without the trivial loop structure.
 
+        Currently, `cut_loop` performs no checks to ensure that the given cutting points are in fact inside
+        the loop's iteration range.
+
         Returns:
             Sequence of ``n`` subtrees representing the respective iteration ranges
         """
@@ -93,7 +96,7 @@ class ReshapeLoops:
                 "Loop cutting for loops with step != 1 is not implemented"
             )
 
-        result: list[PsLoop | PsBlock | PsConditional] = []
+        result: list[PsLoop | PsBlock] = []
         new_start = loop.start
         cutting_points = [self._factory.parse_index(idx) for idx in cutting_points] + [
             loop.stop
@@ -103,7 +106,7 @@ class ReshapeLoops:
             if new_end.structurally_equal(new_start):
                 continue
 
-            num_iters = self._elim_constants(new_end - new_start)
+            num_iters = self._elim_constants(self._typify(new_end - new_start))
             skip = False
 
             if isinstance(num_iters, PsConstantExpr):
@@ -113,8 +116,11 @@ class ReshapeLoops:
                     skip = True
                     cc = CloneContext(self._ctx)
                     cc.symbol_decl(loop.counter.symbol)
+                    local_counter = self._factory.parse_index(
+                        cc.get_replacement(loop.counter.symbol)
+                    )
                     ctr_decl = PsDeclaration(
-                        PsExpression.make(cc.get_replacement(loop.counter.symbol)),
+                        local_counter,
                         new_start,
                     )
                     cloned_body = self._canon_clone.visit(loop.body, cc)
diff --git a/tests/nbackend/transformations/test_reshape_loops.py b/tests/nbackend/transformations/test_reshape_loops.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9c5ff2ee16d484282ccf5a843650fb0f5f3dc6c
--- /dev/null
+++ b/tests/nbackend/transformations/test_reshape_loops.py
@@ -0,0 +1,96 @@
+import sympy as sp
+
+from pystencils import Field, Assignment, make_slice
+from pystencils.backend.kernelcreation import (
+    KernelCreationContext,
+    AstFactory,
+    FullIterationSpace,
+)
+from pystencils.backend.transformations import ReshapeLoops
+
+from pystencils.backend.ast.structural import PsDeclaration, PsBlock, PsLoop, PsConditional
+from pystencils.backend.ast.expressions import PsConstantExpr, PsLt
+
+
+def test_loop_cutting():
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+    reshape = ReshapeLoops(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+
+    f = Field.create_generic("f", 1, index_shape=(2,))
+    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
+    ctx.set_iteration_space(ispace)
+
+    loop_body = PsBlock(
+        [
+            factory.parse_sympy(Assignment(x, 2 * z)),
+            factory.parse_sympy(Assignment(f.center(0), x + y)),
+        ]
+    )
+
+    loop = factory.loops_from_ispace(ispace, loop_body)
+
+    subloops = reshape.cut_loop(loop, [1, 1, 3])
+    assert len(subloops) == 3
+
+    subloop = subloops[0]
+    assert isinstance(subloop, PsBlock)
+    assert isinstance(subloop.statements[0], PsDeclaration)
+    assert subloop.statements[0].declared_symbol.name == "ctr_0__0"
+
+    x_decl = subloop.statements[1]
+    assert isinstance(x_decl, PsDeclaration)
+    assert x_decl.declared_symbol.name == "x__0"
+
+    subloop = subloops[1]
+    assert isinstance(subloop, PsLoop)
+    assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
+    assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3
+
+    x_decl = subloop.body.statements[0]
+    assert isinstance(x_decl, PsDeclaration)
+    assert x_decl.declared_symbol.name == "x__1"
+
+    subloop = subloops[2]
+    assert isinstance(subloop, PsLoop)
+    assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
+    assert subloop.stop.structurally_equal(loop.stop)
+
+
+def test_loop_peeling():
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+    reshape = ReshapeLoops(ctx)
+
+    x, y, z = sp.symbols("x, y, z")
+
+    f = Field.create_generic("f", 1, index_shape=(2,))
+    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
+    ctx.set_iteration_space(ispace)
+
+    loop_body = PsBlock([
+        factory.parse_sympy(Assignment(x, 2 * z)),
+        factory.parse_sympy(Assignment(f.center(0), x + y)),
+    ])
+
+    loop = factory.loops_from_ispace(ispace, loop_body)
+
+    peeled_iters, loop = reshape.peel_loop_front(loop, 3)
+    assert len(peeled_iters) == 3
+
+    for i, peeled_iter in enumerate(peeled_iters):
+        assert isinstance(peeled_iter, PsBlock)
+
+        ctr_decl = peeled_iter.statements[0]
+        assert isinstance(ctr_decl, PsDeclaration)
+        assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
+
+        cond = peeled_iter.statements[1]
+        assert isinstance(cond, PsConditional)
+        assert cond.condition.structurally_equal(PsLt(ctr_decl.lhs, loop.stop))
+
+        subblock = cond.branch_true
+        assert isinstance(subblock.statements[0], PsDeclaration)
+        assert subblock.statements[0].declared_symbol.name == f"x__{i}"