Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
import sympy as sp
from itertools import product
from pystencils import make_slice, fields, Assignment
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsBlock, PsPragma, PsLoop
from pystencils.backend.transformations import InsertPragmasAtLoops, LoopPragma
def test_insert_pragmas():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
f, g = fields("f, g: [3D]")
ispace = FullIterationSpace.create_from_slice(
ctx, make_slice[:, :, :], archetype_field=f
)
ctx.set_iteration_space(ispace)
stencil = list(product([-1, 0, 1], [-1, 0, 1], [-1, 0, 1]))
loop_body = PsBlock([
factory.parse_sympy(Assignment(f.center(0), sum(g.neighbors(stencil))))
])
loops = factory.loops_from_ispace(ispace, loop_body)
pragmas = (
LoopPragma("omp parallel for", 0),
LoopPragma("some nonsense pragma", 1),
LoopPragma("omp simd", -1),
)
add_pragmas = InsertPragmasAtLoops(ctx, pragmas)
ast = add_pragmas(loops)
assert isinstance(ast, PsBlock)
first_pragma = ast.statements[0]
assert isinstance(first_pragma, PsPragma)
assert first_pragma.text == pragmas[0].text
assert ast.statements[1] == loops
second_pragma = loops.body.statements[0]
assert isinstance(second_pragma, PsPragma)
assert second_pragma.text == pragmas[1].text
second_loop = list(dfs_preorder(ast, lambda node: isinstance(node, PsLoop)))[1]
assert isinstance(second_loop, PsLoop)
third_pragma = second_loop.body.statements[0]
assert isinstance(third_pragma, PsPragma)
assert third_pragma.text == pragmas[2].text
......@@ -4,12 +4,18 @@ from pystencils.backend.kernelcreation import (
Typifier,
AstFactory,
)
from pystencils.backend.ast.expressions import PsExpression
from pystencils.backend.ast.expressions import (
PsExpression,
PsEq,
PsGe,
PsGt,
PsLe,
PsLt,
)
from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment
from pystencils.backend.constants import PsConstant
from pystencils.backend.transformations import EliminateBranches
from pystencils.types.quick import Int
from pystencils.backend.ast.expressions import PsGt
i0 = PsExpression.make(PsConstant(0, Int(32)))
......@@ -53,3 +59,39 @@ def test_eliminate_nested_conditional():
result = elim(ast)
assert result.body.statements[0].body.statements[0] == b1
def test_isl():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
typify = Typifier(ctx)
elim = EliminateBranches(ctx)
i = PsExpression.make(ctx.get_symbol("i", ctx.index_dtype))
j = PsExpression.make(ctx.get_symbol("j", ctx.index_dtype))
const_2 = PsExpression.make(PsConstant(2, ctx.index_dtype))
const_4 = PsExpression.make(PsConstant(4, ctx.index_dtype))
a_true = PsBlock([PsComment("a true")])
a_false = PsBlock([PsComment("a false")])
b_true = PsBlock([PsComment("b true")])
b_false = PsBlock([PsComment("b false")])
c_true = PsBlock([PsComment("c true")])
c_false = PsBlock([PsComment("c false")])
a = PsConditional(PsLt(i + j, const_2 * const_4), a_true, a_false)
b = PsConditional(PsGe(j, const_4), b_true, b_false)
c = PsConditional(PsEq(i, const_4), c_true, c_false)
outer_loop = factory.loop(j.symbol.name, slice(0, 3), PsBlock([a, b, c]))
outer_cond = typify(
PsConditional(PsLe(i, const_4), PsBlock([outer_loop]), PsBlock([]))
)
ast = outer_cond
result = elim(ast)
assert result.branch_true.statements[0].body.statements[0] == a_true
assert result.branch_true.statements[0].body.statements[1] == b_false
assert result.branch_true.statements[0].body.statements[2] == c