# Forked from pycodegen / pystencils
# (45 commits behind, 162 commits ahead of the upstream repository).
# test_reshape_loops.py (5.06 KiB)
import sympy as sp
from pystencils import Field, Assignment, make_slice
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.transformations import ReshapeLoops
from pystencils.backend.ast.structural import (
PsDeclaration,
PsBlock,
PsLoop,
PsConditional,
)
from pystencils.backend.ast.expressions import PsConstantExpr, PsGe, PsLt
def test_loop_cutting():
    """Cut a 1D loop at the points ``[1, 1, 3]`` and inspect the resulting pieces.

    The test expects ``cut_loop`` to return exactly three pieces:

    * a plain ``PsBlock`` (not a loop) that declares its own counter ``ctr_0__0``,
    * a ``PsLoop`` with constant bounds ``[1, 3)``,
    * a ``PsLoop`` starting at ``3`` and ending at the original loop's stop.

    The cutting point ``1`` is repeated. Since only three pieces are expected,
    the empty range between the two ``1``s presumably produces no piece.
    Symbols declared inside each piece are expected to carry a per-piece
    suffix (``__0``, ``__1``, ...).
    """
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)
    reshape = ReshapeLoops(ctx)

    x, y, z = sp.symbols("x, y, z")
    f = Field.create_generic("f", 1, index_shape=(2,))

    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
    ctx.set_iteration_space(ispace)

    # Build the loop body: `x = 2 * z; f[0](0) = x + y`.
    assignments = [Assignment(x, 2 * z), Assignment(f.center(0), x + y)]
    body = PsBlock([factory.parse_sympy(asm) for asm in assignments])
    loop = factory.loops_from_ispace(ispace, body)

    pieces = reshape.cut_loop(loop, [1, 1, 3])
    assert len(pieces) == 3
    head, middle, tail = pieces

    # Piece 0: a block that declares renamed copies of the counter and of `x`.
    assert isinstance(head, PsBlock)
    ctr_decl = head.statements[0]
    assert isinstance(ctr_decl, PsDeclaration)
    assert ctr_decl.declared_symbol.name == "ctr_0__0"
    x_decl = head.statements[1]
    assert isinstance(x_decl, PsDeclaration)
    assert x_decl.declared_symbol.name == "x__0"

    # Piece 1: a loop over the constant range [1, 3).
    assert isinstance(middle, PsLoop)
    assert isinstance(middle.start, PsConstantExpr)
    assert middle.start.constant.value == 1
    assert isinstance(middle.stop, PsConstantExpr)
    assert middle.stop.constant.value == 3
    first_stmt = middle.body.statements[0]
    assert isinstance(first_stmt, PsDeclaration)
    assert first_stmt.declared_symbol.name == "x__1"

    # Piece 2: a loop from the last cutting point up to the original stop.
    assert isinstance(tail, PsLoop)
    assert isinstance(tail.start, PsConstantExpr)
    assert tail.start.constant.value == 3
    assert tail.stop.structurally_equal(loop.stop)
def test_loop_peeling():
    """Check ``ReshapeLoops.peel_loop_front`` on a strided loop.

    The iteration space is ``slice(start, 11, step)`` with ``start = 2`` and
    ``step = 3``. Peeling ``num_iters`` iterations off the front is expected
    to produce:

    * one ``PsBlock`` per peeled iteration, which declares a renamed counter
      initialized to ``start + i * step`` and guards the body with
      ``ctr < loop.stop``;
    * a remainder loop starting at ``start + num_iters * step``, with the
      original stop and body.

    The expected counter values are derived from ``start``/``step`` rather
    than hard-coded, so ``num_iters`` can be changed without editing lookups.
    """
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)
    reshape = ReshapeLoops(ctx)

    x, y, z = sp.symbols("x, y, z")

    f = Field.create_generic("f", 1, index_shape=(2,))

    start, step = 2, 3
    ispace = FullIterationSpace.create_from_slice(
        ctx, slice(start, 11, step), archetype_field=f
    )
    ctx.set_iteration_space(ispace)

    loop_body = PsBlock(
        [
            factory.parse_sympy(Assignment(x, 2 * z)),
            factory.parse_sympy(Assignment(f.center(0), x + y)),
        ]
    )

    loop = factory.loops_from_ispace(ispace, loop_body)

    num_iters = 2
    peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
    assert len(peeled_iters) == num_iters

    for i, peeled in enumerate(peeled_iters):
        assert isinstance(peeled, PsBlock)

        # Each peeled iteration declares its own renamed counter, initialized
        # to the counter value of the i-th iteration.
        ctr_decl = peeled.statements[0]
        assert isinstance(ctr_decl, PsDeclaration)
        assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
        expected_ctr = start + i * step
        assert ctr_decl.rhs.structurally_equal(factory.parse_index(expected_ctr))

        # The body is guarded by `ctr < stop`. Presumably this skips peeled
        # iterations that lie past the end of a loop that is too short.
        cond = peeled.statements[1]
        assert isinstance(cond, PsConditional)
        assert cond.condition.structurally_equal(PsLt(ctr_decl.lhs, loop.stop))

        subblock = cond.branch_true
        assert isinstance(subblock.statements[0], PsDeclaration)
        assert subblock.statements[0].declared_symbol.name == f"x__{i}"

    # The remainder loop resumes right after the last peeled iteration.
    expected_start = start + num_iters * step
    assert peeled_loop.start.structurally_equal(factory.parse_index(expected_start))
    assert peeled_loop.stop.structurally_equal(loop.stop)
    assert peeled_loop.body.structurally_equal(loop.body)
def test_loop_peeling_back():
    """Check ``ReshapeLoops.peel_loop_back`` on a loop over ``make_slice[:]``.

    Peeling ``num_iters`` iterations off the back is expected to return:

    * the remainder loop, with the original start and body and a stop of
      ``original_stop - num_iters``;
    * one ``PsBlock`` per peeled iteration, which declares a renamed counter
      and guards the body with ``ctr >= loop.start``.
    """
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)
    reshape = ReshapeLoops(ctx)

    x, y, z = sp.symbols("x, y, z")

    f = Field.create_generic("f", 1, index_shape=(2,))
    ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
    ctx.set_iteration_space(ispace)

    loop_body = PsBlock(
        [
            factory.parse_sympy(Assignment(x, 2 * z)),
            factory.parse_sympy(Assignment(f.center(0), x + y)),
        ]
    )

    loop = factory.loops_from_ispace(ispace, loop_body)

    num_iters = 3
    peeled_loop, peeled_iters = reshape.peel_loop_back(loop, num_iters)
    # Compare against `num_iters` (not a literal) so the two cannot drift apart.
    assert len(peeled_iters) == num_iters

    for i, peeled in enumerate(peeled_iters):
        assert isinstance(peeled, PsBlock)

        ctr_decl = peeled.statements[0]
        assert isinstance(ctr_decl, PsDeclaration)
        assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
        # NOTE(review): unlike test_loop_peeling, the counter's initial value
        # (ctr_decl.rhs) is not checked here; consider adding that check.

        # The body is guarded by `ctr >= start`. Presumably this skips peeled
        # iterations that fall before the loop start when the loop is too short.
        cond = peeled.statements[1]
        assert isinstance(cond, PsConditional)
        assert cond.condition.structurally_equal(PsGe(ctr_decl.lhs, loop.start))

        subblock = cond.branch_true
        assert isinstance(subblock.statements[0], PsDeclaration)
        assert subblock.statements[0].declared_symbol.name == f"x__{i}"

    assert peeled_loop.start.structurally_equal(loop.start)
    # The expected stop is built from a freshly generated copy of the original
    # loop's stop expression. NOTE(review): presumably this avoids reusing
    # `loop.stop`'s nodes; confirm. `stop - num_iters` relies on the default
    # step of `make_slice[:]`.
    expected_stop = (
        factory.loops_from_ispace(ispace, loop_body).stop
        - factory.parse_index(num_iters)
    )
    assert peeled_loop.stop.structurally_equal(expected_stop)
    assert peeled_loop.body.structurally_equal(loop.body)