Skip to content
Snippets Groups Projects
Select Git revision
  • be262db4d066f65e461d4a39c0b6de6fc2ca8f54
  • master default protected
  • develop
  • UpSympy
  • struct
  • FixMod
  • testing
  • VecSparse
  • dev
  • FixDoc
  • SmallFixes
  • UpgradeSympy
  • Revision3
  • TypeSystem
  • BenchMarkCUDA
  • SmallPolish
  • Testing
  • testConda
  • int64
  • assembly
  • warnings
  • release/1.3.6
  • release/2.0.dev0
  • release/1.3.5
  • release/1.3.4
  • release/1.3.3
  • release/1.3.2
  • release/1.3.1
  • release/1.3
  • release/1.2
  • release/1.1.1
  • release/1.1
  • release/1.0.1
  • release/1.0
  • release/0.4.4
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3
  • release/0.4.2
  • release/0.4.1
41 results

setup.py

Blame
  • Forked from pycodegen / pystencils
    Source project has a limited visibility.
    test_reshape_loops.py 5.06 KiB
    import sympy as sp
    
    from pystencils import Field, Assignment, make_slice
    from pystencils.backend.kernelcreation import (
        KernelCreationContext,
        AstFactory,
        FullIterationSpace,
    )
    from pystencils.backend.transformations import ReshapeLoops
    
    from pystencils.backend.ast.structural import (
        PsDeclaration,
        PsBlock,
        PsLoop,
        PsConditional,
    )
    from pystencils.backend.ast.expressions import PsConstantExpr, PsGe, PsLt
    
    
    def test_loop_cutting():
        ctx = KernelCreationContext()
        factory = AstFactory(ctx)
        reshape = ReshapeLoops(ctx)
    
        x, y, z = sp.symbols("x, y, z")
    
        f = Field.create_generic("f", 1, index_shape=(2,))
        ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
        ctx.set_iteration_space(ispace)
    
        loop_body = PsBlock(
            [
                factory.parse_sympy(Assignment(x, 2 * z)),
                factory.parse_sympy(Assignment(f.center(0), x + y)),
            ]
        )
    
        loop = factory.loops_from_ispace(ispace, loop_body)
    
        subloops = reshape.cut_loop(loop, [1, 1, 3])
        assert len(subloops) == 3
    
        subloop = subloops[0]
        assert isinstance(subloop, PsBlock)
        assert isinstance(subloop.statements[0], PsDeclaration)
        assert subloop.statements[0].declared_symbol.name == "ctr_0__0"
    
        x_decl = subloop.statements[1]
        assert isinstance(x_decl, PsDeclaration)
        assert x_decl.declared_symbol.name == "x__0"
    
        subloop = subloops[1]
        assert isinstance(subloop, PsLoop)
        assert (
            isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
        )
        assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3
    
        x_decl = subloop.body.statements[0]
        assert isinstance(x_decl, PsDeclaration)
        assert x_decl.declared_symbol.name == "x__1"
    
        subloop = subloops[2]
        assert isinstance(subloop, PsLoop)
        assert (
            isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
        )
        assert subloop.stop.structurally_equal(loop.stop)
    
    
    def test_loop_peeling():
        ctx = KernelCreationContext()
        factory = AstFactory(ctx)
        reshape = ReshapeLoops(ctx)
    
        x, y, z = sp.symbols("x, y, z")
    
        f = Field.create_generic("f", 1, index_shape=(2,))
        ispace = FullIterationSpace.create_from_slice(
            ctx, slice(2, 11, 3), archetype_field=f
        )
        ctx.set_iteration_space(ispace)
    
        loop_body = PsBlock(
            [
                factory.parse_sympy(Assignment(x, 2 * z)),
                factory.parse_sympy(Assignment(f.center(0), x + y)),
            ]
        )
    
        loop = factory.loops_from_ispace(ispace, loop_body)
    
        num_iters = 2
        peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
        assert len(peeled_iters) == num_iters
    
        for i, iter in enumerate(peeled_iters):
            assert isinstance(iter, PsBlock)
    
            ctr_decl = iter.statements[0]
            assert isinstance(ctr_decl, PsDeclaration)
            assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
            ctr_value = {0: 2, 1: 5}[i]
            assert ctr_decl.rhs.structurally_equal(factory.parse_index(ctr_value))
    
            cond = iter.statements[1]
            assert isinstance(cond, PsConditional)
            assert cond.condition.structurally_equal(PsLt(ctr_decl.lhs, loop.stop))
    
            subblock = cond.branch_true
            assert isinstance(subblock.statements[0], PsDeclaration)
            assert subblock.statements[0].declared_symbol.name == f"x__{i}"
    
        assert peeled_loop.start.structurally_equal(factory.parse_index(8))
        assert peeled_loop.stop.structurally_equal(loop.stop)
        assert peeled_loop.body.structurally_equal(loop.body)
    
    
    def test_loop_peeling_back():
        ctx = KernelCreationContext()
        factory = AstFactory(ctx)
        reshape = ReshapeLoops(ctx)
    
        x, y, z = sp.symbols("x, y, z")
    
        f = Field.create_generic("f", 1, index_shape=(2,))
        ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
        ctx.set_iteration_space(ispace)
    
        loop_body = PsBlock(
            [
                factory.parse_sympy(Assignment(x, 2 * z)),
                factory.parse_sympy(Assignment(f.center(0), x + y)),
            ]
        )
    
        loop = factory.loops_from_ispace(ispace, loop_body)
    
        num_iters = 3
        peeled_loop, peeled_iters = reshape.peel_loop_back(loop, num_iters)
        assert len(peeled_iters) == 3
    
        for i, iter in enumerate(peeled_iters):
            assert isinstance(iter, PsBlock)
    
            ctr_decl = iter.statements[0]
            assert isinstance(ctr_decl, PsDeclaration)
            assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
    
            cond = iter.statements[1]
            assert isinstance(cond, PsConditional)
            assert cond.condition.structurally_equal(PsGe(ctr_decl.lhs, loop.start))
    
            subblock = cond.branch_true
            assert isinstance(subblock.statements[0], PsDeclaration)
            assert subblock.statements[0].declared_symbol.name == f"x__{i}"
    
        assert peeled_loop.start.structurally_equal(loop.start)
        assert peeled_loop.stop.structurally_equal(
            factory.loops_from_ispace(ispace, loop_body).stop
            - factory.parse_index(num_iters)
        )
        assert peeled_loop.body.structurally_equal(loop.body)