Skip to content
Snippets Groups Projects
Commit e9ab9e75 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Add loop reshaping test cases

parent 9c5e12ea
No related branches found
No related tags found
1 merge request!376Loop Transformations: Cutting and Peeling
Pipeline #64983 passed
......@@ -121,7 +121,7 @@ class FullIterationSpace(IterationSpace):
@staticmethod
def create_from_slice(
ctx: KernelCreationContext,
iteration_slice: Sequence[slice],
iteration_slice: slice | Sequence[slice],
archetype_field: Field | None = None,
):
"""Create an iteration space from a sequence of slices, optionally over an archetype field.
......@@ -131,6 +131,9 @@ class FullIterationSpace(IterationSpace):
iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice`
archetype_field: Optionally, an archetype field that dictates the upper slice limits and loop order.
"""
if isinstance(iteration_slice, slice):
iteration_slice = (iteration_slice,)
dim = len(iteration_slice)
if dim == 0:
raise ValueError(
......
......@@ -72,7 +72,7 @@ class ReshapeLoops:
def cut_loop(
self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
) -> Sequence[PsLoop | PsBlock | PsConditional]:
) -> Sequence[PsLoop | PsBlock]:
"""Cut a loop at the given cutting points.
Cut the given loop at the iterations specified by the given cutting points,
......@@ -82,6 +82,9 @@ class ReshapeLoops:
Resulting subtrees representing zero iterations are dropped; subtrees representing exactly one iteration are
returned without the trivial loop structure.
Currently, `cut_loop` performs no checks to ensure that the given cutting points are in fact inside
the loop's iteration range.
Returns:
Sequence of ``n`` subtrees representing the respective iteration ranges
"""
......@@ -93,7 +96,7 @@ class ReshapeLoops:
"Loop cutting for loops with step != 1 is not implemented"
)
result: list[PsLoop | PsBlock | PsConditional] = []
result: list[PsLoop | PsBlock] = []
new_start = loop.start
cutting_points = [self._factory.parse_index(idx) for idx in cutting_points] + [
loop.stop
......@@ -103,7 +106,7 @@ class ReshapeLoops:
if new_end.structurally_equal(new_start):
continue
num_iters = self._elim_constants(new_end - new_start)
num_iters = self._elim_constants(self._typify(new_end - new_start))
skip = False
if isinstance(num_iters, PsConstantExpr):
......@@ -113,8 +116,11 @@ class ReshapeLoops:
skip = True
cc = CloneContext(self._ctx)
cc.symbol_decl(loop.counter.symbol)
local_counter = self._factory.parse_index(
cc.get_replacement(loop.counter.symbol)
)
ctr_decl = PsDeclaration(
PsExpression.make(cc.get_replacement(loop.counter.symbol)),
local_counter,
new_start,
)
cloned_body = self._canon_clone.visit(loop.body, cc)
......
import sympy as sp
from pystencils import Field, Assignment, make_slice
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.transformations import ReshapeLoops
from pystencils.backend.ast.structural import PsDeclaration, PsBlock, PsLoop, PsConditional
from pystencils.backend.ast.expressions import PsConstantExpr, PsLt
def test_loop_cutting():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
reshape = ReshapeLoops(ctx)
x, y, z = sp.symbols("x, y, z")
f = Field.create_generic("f", 1, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
ctx.set_iteration_space(ispace)
loop_body = PsBlock(
[
factory.parse_sympy(Assignment(x, 2 * z)),
factory.parse_sympy(Assignment(f.center(0), x + y)),
]
)
loop = factory.loops_from_ispace(ispace, loop_body)
subloops = reshape.cut_loop(loop, [1, 1, 3])
assert len(subloops) == 3
subloop = subloops[0]
assert isinstance(subloop, PsBlock)
assert isinstance(subloop.statements[0], PsDeclaration)
assert subloop.statements[0].declared_symbol.name == "ctr_0__0"
x_decl = subloop.statements[1]
assert isinstance(x_decl, PsDeclaration)
assert x_decl.declared_symbol.name == "x__0"
subloop = subloops[1]
assert isinstance(subloop, PsLoop)
assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3
x_decl = subloop.body.statements[0]
assert isinstance(x_decl, PsDeclaration)
assert x_decl.declared_symbol.name == "x__1"
subloop = subloops[2]
assert isinstance(subloop, PsLoop)
assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
assert subloop.stop.structurally_equal(loop.stop)
def test_loop_peeling():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
reshape = ReshapeLoops(ctx)
x, y, z = sp.symbols("x, y, z")
f = Field.create_generic("f", 1, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
ctx.set_iteration_space(ispace)
loop_body = PsBlock([
factory.parse_sympy(Assignment(x, 2 * z)),
factory.parse_sympy(Assignment(f.center(0), x + y)),
])
loop = factory.loops_from_ispace(ispace, loop_body)
peeled_iters, loop = reshape.peel_loop_front(loop, 3)
assert len(peeled_iters) == 3
for i, iter in enumerate(peeled_iters):
assert isinstance(iter, PsBlock)
ctr_decl = iter.statements[0]
assert isinstance(ctr_decl, PsDeclaration)
assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
cond = iter.statements[1]
assert isinstance(cond, PsConditional)
assert cond.condition.structurally_equal(PsLt(ctr_decl.lhs, loop.stop))
subblock = cond.branch_true
assert isinstance(subblock.statements[0], PsDeclaration)
assert subblock.statements[0].declared_symbol.name == f"x__{i}"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment