Skip to content
Snippets Groups Projects
test_add_pragmas.py 1.75 KiB
import sympy as sp
from itertools import product

from pystencils import make_slice, fields, Assignment
from pystencils.backend.kernelcreation import (
    KernelCreationContext,
    AstFactory,
    FullIterationSpace,
)

from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsBlock, PsPragma, PsLoop
from pystencils.backend.transformations import InsertPragmasAtLoops, LoopPragma

def test_insert_pragmas():
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)

    f, g = fields("f, g: [3D]")
    ispace = FullIterationSpace.create_from_slice(
        ctx, make_slice[:, :, :], archetype_field=f
    )
    ctx.set_iteration_space(ispace)

    stencil = list(product([-1, 0, 1], [-1, 0, 1], [-1, 0, 1]))
    loop_body = PsBlock([
        factory.parse_sympy(Assignment(f.center(0), sum(g.neighbors(stencil))))
    ])
    loops = factory.loops_from_ispace(ispace, loop_body)
    
    pragmas = (
        LoopPragma("omp parallel for", 0),
        LoopPragma("some nonsense pragma", 1),
        LoopPragma("omp simd", -1),
    )
    add_pragmas = InsertPragmasAtLoops(ctx, pragmas)
    ast = add_pragmas(loops)

    assert isinstance(ast, PsBlock)
    
    first_pragma = ast.statements[0]
    assert isinstance(first_pragma, PsPragma)
    assert first_pragma.text == pragmas[0].text

    assert ast.statements[1] == loops
    second_pragma = loops.body.statements[0]
    assert isinstance(second_pragma, PsPragma)
    assert second_pragma.text == pragmas[1].text

    second_loop = list(dfs_preorder(ast, lambda node: isinstance(node, PsLoop)))[1]
    assert isinstance(second_loop, PsLoop)
    third_pragma = second_loop.body.statements[0]
    assert isinstance(third_pragma, PsPragma)
    assert third_pragma.text == pragmas[2].text