Skip to content
Snippets Groups Projects
Select Git revision
  • 65ddbe067630a7fbca093886bbf364833742594e
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

kernel_decorator.py

Blame
  • test_reshape_loops.py 5.06 KiB
    import sympy as sp
    
    from pystencils import Field, Assignment, make_slice
    from pystencils.backend.kernelcreation import (
        KernelCreationContext,
        AstFactory,
        FullIterationSpace,
    )
    from pystencils.backend.transformations import ReshapeLoops
    
    from pystencils.backend.ast.structural import (
        PsDeclaration,
        PsBlock,
        PsLoop,
        PsConditional,
    )
    from pystencils.backend.ast.expressions import PsConstantExpr, PsGe, PsLt
    
    
    def test_loop_cutting():
        """Check ``ReshapeLoops.cut_loop`` on a full one-dimensional iteration space.

        Cutting at ``[1, 1, 3]`` is expected to produce exactly three pieces:
        a plain block for the single iteration before 1, a loop over ``[1, 3)``,
        and a loop from 3 up to the original stop. Symbols declared inside
        each piece receive an index suffix (``__0``, ``__1``, ...).
        """
        ctx = KernelCreationContext()
        factory = AstFactory(ctx)
        reshape = ReshapeLoops(ctx)

        x, y, z = sp.symbols("x, y, z")

        field = Field.create_generic("f", 1, index_shape=(2,))
        ispace = FullIterationSpace.create_from_slice(
            ctx, make_slice[:], archetype_field=field
        )
        ctx.set_iteration_space(ispace)

        body_statements = [
            factory.parse_sympy(asm)
            for asm in (
                Assignment(x, 2 * z),
                Assignment(field.center(0), x + y),
            )
        ]
        loop = factory.loops_from_ispace(ispace, PsBlock(body_statements))

        pieces = reshape.cut_loop(loop, [1, 1, 3])
        assert len(pieces) == 3
        head, middle, tail = pieces

        # Head: the single iteration before the first cut is emitted as a
        # block that declares a suffixed counter followed by a suffixed ``x``.
        assert isinstance(head, PsBlock)
        counter_decl = head.statements[0]
        assert isinstance(counter_decl, PsDeclaration)
        assert counter_decl.declared_symbol.name == "ctr_0__0"
        head_x_decl = head.statements[1]
        assert isinstance(head_x_decl, PsDeclaration)
        assert head_x_decl.declared_symbol.name == "x__0"

        # Middle: a loop over the constant range [1, 3).
        assert isinstance(middle, PsLoop)
        assert isinstance(middle.start, PsConstantExpr)
        assert middle.start.constant.value == 1
        assert isinstance(middle.stop, PsConstantExpr)
        assert middle.stop.constant.value == 3
        middle_x_decl = middle.body.statements[0]
        assert isinstance(middle_x_decl, PsDeclaration)
        assert middle_x_decl.declared_symbol.name == "x__1"

        # Tail: starts at 3 and keeps the original loop's stop expression.
        assert isinstance(tail, PsLoop)
        assert isinstance(tail.start, PsConstantExpr)
        assert tail.start.constant.value == 3
        assert tail.stop.structurally_equal(loop.stop)
    
    
    def test_loop_peeling():
        """Check ``ReshapeLoops.peel_loop_front`` on a strided loop.

        The iteration space ``slice(2, 11, 3)`` visits counter values 2, 5 and 8.
        Peeling ``num_iters`` iterations from the front must yield one guarded
        block per peeled iteration, each declaring a suffixed counter set to
        that iteration's value. The remainder loop must start at the first
        non-peeled counter value and keep the original stop and body.
        """
        ctx = KernelCreationContext()
        factory = AstFactory(ctx)
        reshape = ReshapeLoops(ctx)

        x, y, z = sp.symbols("x, y, z")

        start, stop, step = 2, 11, 3
        f = Field.create_generic("f", 1, index_shape=(2,))
        ispace = FullIterationSpace.create_from_slice(
            ctx, slice(start, stop, step), archetype_field=f
        )
        ctx.set_iteration_space(ispace)

        loop_body = PsBlock(
            [
                factory.parse_sympy(Assignment(x, 2 * z)),
                factory.parse_sympy(Assignment(f.center(0), x + y)),
            ]
        )

        loop = factory.loops_from_ispace(ispace, loop_body)

        num_iters = 2
        peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
        assert len(peeled_iters) == num_iters

        # Named ``peeled`` rather than ``iter`` to avoid shadowing the builtin.
        for i, peeled in enumerate(peeled_iters):
            assert isinstance(peeled, PsBlock)

            # Each peeled iteration declares its own suffixed copy of the
            # counter, initialized to that iteration's counter value. The
            # value is derived from the slice so it stays valid if
            # ``num_iters`` or the slice parameters change.
            ctr_decl = peeled.statements[0]
            assert isinstance(ctr_decl, PsDeclaration)
            assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
            expected_ctr = start + i * step
            assert ctr_decl.rhs.structurally_equal(factory.parse_index(expected_ctr))

            # The body of the peeled iteration is guarded by ``ctr < stop``.
            cond = peeled.statements[1]
            assert isinstance(cond, PsConditional)
            assert cond.condition.structurally_equal(PsLt(ctr_decl.lhs, loop.stop))

            subblock = cond.branch_true
            assert isinstance(subblock.statements[0], PsDeclaration)
            assert subblock.statements[0].declared_symbol.name == f"x__{i}"

        # The remainder loop resumes at the first counter value not peeled off.
        expected_start = start + num_iters * step
        assert peeled_loop.start.structurally_equal(factory.parse_index(expected_start))
        assert peeled_loop.stop.structurally_equal(loop.stop)
        assert peeled_loop.body.structurally_equal(loop.body)
    
    
    def test_loop_peeling_back():
        """Check ``ReshapeLoops.peel_loop_back`` on a full one-dimensional loop.

        Peeling ``num_iters`` iterations from the back must yield one guarded
        block per peeled iteration, each declaring a suffixed counter. The
        remainder loop must keep the original start and body, with its stop
        reduced by ``num_iters``.
        """
        ctx = KernelCreationContext()
        factory = AstFactory(ctx)
        reshape = ReshapeLoops(ctx)

        x, y, z = sp.symbols("x, y, z")

        f = Field.create_generic("f", 1, index_shape=(2,))
        ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
        ctx.set_iteration_space(ispace)

        loop_body = PsBlock(
            [
                factory.parse_sympy(Assignment(x, 2 * z)),
                factory.parse_sympy(Assignment(f.center(0), x + y)),
            ]
        )

        loop = factory.loops_from_ispace(ispace, loop_body)

        num_iters = 3
        peeled_loop, peeled_iters = reshape.peel_loop_back(loop, num_iters)
        # Compare against ``num_iters``, not a literal, so the assertion
        # tracks the requested peel count.
        assert len(peeled_iters) == num_iters

        # Named ``peeled`` rather than ``iter`` to avoid shadowing the builtin.
        for i, peeled in enumerate(peeled_iters):
            assert isinstance(peeled, PsBlock)

            ctr_decl = peeled.statements[0]
            assert isinstance(ctr_decl, PsDeclaration)
            assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"

            # The body of a back-peeled iteration is guarded by
            # ``ctr >= start``. Presumably this keeps it from running when
            # the loop has fewer than ``num_iters`` iterations.
            cond = peeled.statements[1]
            assert isinstance(cond, PsConditional)
            assert cond.condition.structurally_equal(PsGe(ctr_decl.lhs, loop.start))

            subblock = cond.branch_true
            assert isinstance(subblock.statements[0], PsDeclaration)
            assert subblock.statements[0].declared_symbol.name == f"x__{i}"

        assert peeled_loop.start.structurally_equal(loop.start)
        # The expected stop is built from a freshly generated loop rather than
        # by wrapping ``loop.stop`` itself. Presumably this avoids sharing AST
        # nodes between the original loop and the expected expression.
        expected_stop = (
            factory.loops_from_ispace(ispace, loop_body).stop
            - factory.parse_index(num_iters)
        )
        assert peeled_loop.stop.structurally_equal(expected_stop)
        assert peeled_loop.body.structurally_equal(loop.body)