Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 1865 additions and 13 deletions
import sympy as sp
import pytest
from pystencils import Assignment, TypedSymbol, fields, FieldType, make_slice
from pystencils.sympyextensions import tcast, mem_acc
from pystencils.sympyextensions.pointers import AddressOf
from pystencils.backend.constants import PsConstant
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
Typifier,
)
from pystencils.backend.transformations import (
VectorizationAxis,
VectorizationContext,
AstVectorizer,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import (
PsBlock,
PsDeclaration,
PsAssignment,
PsLoop,
)
from pystencils.backend.ast.expressions import (
PsSymbolExpr,
PsConstantExpr,
PsExpression,
PsCast,
PsMemAcc,
PsCall,
PsSubscript,
)
from pystencils.backend.functions import CFunction
from pystencils.backend.ast.vector import PsVecBroadcast, PsVecMemAcc
from pystencils.backend.exceptions import VectorizationError
from pystencils.types import PsArrayType, PsVectorType, deconstify, create_type
def test_vectorize_expressions():
x, y, z, w = sp.symbols("x, y, z, w")
ctx = KernelCreationContext()
factory = AstFactory(ctx)
typify = Typifier(ctx)
for s in (x, y, z, w):
_ = factory.parse_sympy(s)
ctr = ctx.get_symbol("ctr", ctx.index_dtype)
axis = VectorizationAxis(ctr)
vc = VectorizationContext(ctx, 4, axis)
vc.vectorize_symbol(ctx.get_symbol("x"))
vc.vectorize_symbol(ctx.get_symbol("w"))
vectorize = AstVectorizer(ctx)
for expr in [
factory.parse_sympy(-x * y + 13 * z - 4 * (x / w) * (x + z)),
factory.parse_sympy(sp.sin(x + z) - sp.cos(w)),
factory.parse_sympy(y**2 - x**2),
typify(
-factory.parse_sympy(x / (w**2))
), # place the negation outside, since SymPy would remove it
factory.parse_sympy(13 + (1 / w) - sp.exp(x) * 24),
]:
vec_expr = vectorize.visit(expr, vc)
# Must be a clone
assert vec_expr is not expr
scalar_type = ctx.default_dtype
vector_type = PsVectorType(scalar_type, 4)
for subexpr in dfs_preorder(vec_expr):
match subexpr:
case PsSymbolExpr(symb) if symb.name in "yz":
# These are not vectorized, but broadcast
assert symb.dtype == scalar_type
assert subexpr.dtype == scalar_type
case PsConstantExpr(c):
assert deconstify(c.get_dtype()) == scalar_type
assert subexpr.dtype == scalar_type
case PsSymbolExpr(symb):
assert symb.name not in "xw"
assert symb.get_dtype() == vector_type
assert subexpr.dtype == vector_type
case PsVecBroadcast(lanes, operand):
assert lanes == 4
assert subexpr.dtype == vector_type
assert subexpr.dtype.scalar_type == operand.dtype
case PsExpression():
# All other expressions are vectorized
assert subexpr.dtype == vector_type
def test_vectorize_casts_and_counter():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
ctr = ctx.get_symbol("ctr", ctx.index_dtype)
vec_ctr = ctx.get_symbol("vec_ctr", PsVectorType(ctx.index_dtype, 4))
vectorize = AstVectorizer(ctx)
axis = VectorizationAxis(ctr, vec_ctr)
vc = VectorizationContext(ctx, 4, axis)
expr = factory.parse_sympy(tcast(sp.Symbol("ctr"), create_type("float32")))
vec_expr = vectorize.visit(expr, vc)
assert isinstance(vec_expr, PsCast)
assert (
vec_expr.dtype
== vec_expr.target_type
== PsVectorType(create_type("float32"), 4)
)
assert isinstance(vec_expr.operand, PsSymbolExpr)
assert vec_expr.operand.symbol == vec_ctr
assert vec_expr.operand.dtype == PsVectorType(ctx.index_dtype, 4)
def test_invalid_vectorization():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
typify = Typifier(ctx)
ctr = ctx.get_symbol("ctr", ctx.index_dtype)
vectorize = AstVectorizer(ctx)
axis = VectorizationAxis(ctr)
vc = VectorizationContext(ctx, 4, axis)
expr = factory.parse_sympy(tcast(sp.Symbol("ctr"), create_type("float32")))
with pytest.raises(VectorizationError):
# Fails since no vectorized counter was specified
_ = vectorize.visit(expr, vc)
expr = PsExpression.make(
ctx.get_symbol("x_v", PsVectorType(create_type("float32"), 4))
)
with pytest.raises(VectorizationError):
# Fails since this symbol is already vectorial
_ = vectorize.visit(expr, vc)
func = CFunction("compute", [ctx.default_dtype], ctx.default_dtype)
expr = typify(PsCall(func, [PsExpression.make(ctx.get_symbol("x"))]))
with pytest.raises(VectorizationError):
# Can't vectorize unknown function
_ = vectorize.visit(expr, vc)
def test_vectorize_declarations():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
x, y, z, w = sp.symbols("x, y, z, w")
ctr = TypedSymbol("ctr", ctx.index_dtype)
vectorize = AstVectorizer(ctx)
axis = VectorizationAxis(
ctx.get_symbol("ctr", ctx.index_dtype),
ctx.get_symbol("vec_ctr", PsVectorType(ctx.index_dtype, 4)),
)
vc = VectorizationContext(ctx, 4, axis)
block = PsBlock(
[
factory.parse_sympy(asm)
for asm in [
Assignment(x, tcast.as_numeric(ctr)),
Assignment(y, sp.cos(x)),
Assignment(z, x**2 + 2 * y / 4),
Assignment(w, -x + y - z),
]
]
)
vec_block = vectorize.visit(block, vc)
assert vec_block is not block
assert isinstance(vec_block, PsBlock)
for symb_name, decl in zip("xyzw", vec_block.statements):
symb = ctx.get_symbol(symb_name)
assert symb in vc.vectorized_symbols
assert isinstance(decl, PsDeclaration)
assert decl.declared_symbol == vc.vectorized_symbols[symb]
assert (
decl.lhs.dtype
== decl.declared_symbol.dtype
== PsVectorType(ctx.default_dtype, 4)
)
def test_duplicate_declarations():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
x, y = sp.symbols("x, y")
vectorize = AstVectorizer(ctx)
axis = VectorizationAxis(
ctx.get_symbol("ctr", ctx.index_dtype),
)
vc = VectorizationContext(ctx, 4, axis)
block = PsBlock(
[
factory.parse_sympy(asm)
for asm in [
Assignment(y, sp.cos(x)),
Assignment(y, 21),
]
]
)
with pytest.raises(VectorizationError):
_ = vectorize.visit(block, vc)
def test_reject_symbol_assignments():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
x, y = sp.symbols("x, y")
vectorize = AstVectorizer(ctx)
axis = VectorizationAxis(
ctx.get_symbol("ctr", ctx.index_dtype),
)
vc = VectorizationContext(ctx, 4, axis)
asm = PsAssignment(factory.parse_sympy(x), factory.parse_sympy(3 + y))
with pytest.raises(VectorizationError):
_ = vectorize.visit(asm, vc)
def test_vectorize_assignments():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
x, y = sp.symbols("x, y")
vectorize = AstVectorizer(ctx)
axis = VectorizationAxis(
ctx.get_symbol("ctr", ctx.index_dtype),
)
vc = VectorizationContext(ctx, 4, axis)
decl = PsDeclaration(factory.parse_sympy(x), factory.parse_sympy(sp.sympify(0)))
asm = PsAssignment(factory.parse_sympy(x), factory.parse_sympy(3 + y))
ast = PsBlock([decl, asm])
vec_ast = vectorize.visit(ast, vc)
vec_asm = vec_ast.statements[1]
assert isinstance(vec_asm, PsAssignment)
assert isinstance(vec_asm.lhs.symbol.dtype, PsVectorType)
def test_vectorize_memory_assignments():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
typify = Typifier(ctx)
vectorize = AstVectorizer(ctx)
x, y = sp.symbols("x, y")
ctr = TypedSymbol("ctr", ctx.index_dtype)
i = TypedSymbol("i", ctx.index_dtype)
axis = VectorizationAxis(
ctx.get_symbol("ctr", ctx.index_dtype),
)
vc = VectorizationContext(ctx, 4, axis)
ptr = TypedSymbol("ptr", create_type("float64 *"))
asm = typify(
PsAssignment(
factory.parse_sympy(mem_acc(ptr, 3 * ctr + 2)),
factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3)),
)
)
vec_asm = vectorize.visit(asm, vc)
assert isinstance(vec_asm, PsAssignment)
assert isinstance(vec_asm.lhs, PsVecMemAcc)
field = fields("field(1): [2D]", field_type=FieldType.CUSTOM)
asm = factory.parse_sympy(
Assignment(
field.absolute_access((ctr, i), (0,)),
x + y * field.absolute_access((ctr + 1, i), (0,)),
)
)
vec_asm = vectorize.visit(asm, vc)
assert isinstance(vec_asm, PsAssignment)
assert isinstance(vec_asm.lhs, PsVecMemAcc)
def test_invalid_memory_assignments():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
typify = Typifier(ctx)
vectorize = AstVectorizer(ctx)
x, y = sp.symbols("x, y")
ctr = TypedSymbol("ctr", ctx.index_dtype)
axis = VectorizationAxis(
ctx.get_symbol("ctr", ctx.index_dtype),
)
vc = VectorizationContext(ctx, 4, axis)
i = TypedSymbol("i", ctx.index_dtype)
ptr = TypedSymbol("ptr", create_type("float64 *"))
# Cannot vectorize assignment to LHS that does not depend on axis counter
asm = typify(
PsAssignment(
factory.parse_sympy(mem_acc(ptr, 3 * i + 2)),
factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3)),
)
)
with pytest.raises(VectorizationError):
_ = vectorize.visit(asm, vc)
def test_vectorize_mem_acc():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
typify = Typifier(ctx)
vectorize = AstVectorizer(ctx)
ctr = TypedSymbol("ctr", ctx.index_dtype)
axis = VectorizationAxis(
ctx.get_symbol("ctr", ctx.index_dtype),
)
vc = VectorizationContext(ctx, 4, axis)
i = TypedSymbol("i", ctx.index_dtype)
j = TypedSymbol("j", ctx.index_dtype)
ptr = TypedSymbol("ptr", create_type("float64 *"))
# Lane-invariant index
acc = factory.parse_sympy(mem_acc(ptr, 3 * i + 5 * j))
vec_acc = vectorize.visit(acc, vc)
assert isinstance(vec_acc, PsVecBroadcast)
assert vec_acc.operand is not acc
assert vec_acc.operand.structurally_equal(acc)
# Counter as index
acc = factory.parse_sympy(mem_acc(ptr, ctr))
assert isinstance(acc, PsMemAcc)
vec_acc = vectorize.visit(acc, vc)
assert isinstance(vec_acc, PsVecMemAcc)
assert vec_acc.pointer is not acc.pointer
assert vec_acc.pointer.structurally_equal(acc.pointer)
assert vec_acc.offset is not acc.offset
assert vec_acc.offset.structurally_equal(acc.offset)
assert vec_acc.stride is None
assert vec_acc.vector_entries == 4
# Simple affine
acc = factory.parse_sympy(mem_acc(ptr, 3 * i + 5 * ctr))
assert isinstance(acc, PsMemAcc)
vec_acc = vectorize.visit(acc, vc)
assert isinstance(vec_acc, PsVecMemAcc)
assert vec_acc.pointer is not acc.pointer
assert vec_acc.pointer.structurally_equal(acc.pointer)
assert vec_acc.offset is not acc.offset
assert vec_acc.offset.structurally_equal(acc.offset)
assert vec_acc.stride.structurally_equal(factory.parse_index(5))
assert vec_acc.vector_entries == 4
# More complex, nested affine
acc = factory.parse_sympy(mem_acc(ptr, j * i + 2 * (5 + j * ctr) + 2 * ctr))
assert isinstance(acc, PsMemAcc)
vec_acc = vectorize.visit(acc, vc)
assert isinstance(vec_acc, PsVecMemAcc)
assert vec_acc.pointer is not acc.pointer
assert vec_acc.pointer.structurally_equal(acc.pointer)
assert vec_acc.offset is not acc.offset
assert vec_acc.offset.structurally_equal(acc.offset)
assert vec_acc.stride.structurally_equal(factory.parse_index(2 * j + 2))
assert vec_acc.vector_entries == 4
# Even more complex affine
idx = -factory.parse_index(ctr) / factory.parse_index(i) - factory.parse_index(
ctr
) * factory.parse_index(j)
acc = typify(PsMemAcc(factory.parse_sympy(ptr), idx))
assert isinstance(acc, PsMemAcc)
vec_acc = vectorize.visit(acc, vc)
assert isinstance(vec_acc, PsVecMemAcc)
assert vec_acc.pointer is not acc.pointer
assert vec_acc.pointer.structurally_equal(acc.pointer)
assert vec_acc.offset is not acc.offset
assert vec_acc.offset.structurally_equal(acc.offset)
assert vec_acc.stride.structurally_equal(
factory.parse_index(-1) / factory.parse_index(i) - factory.parse_index(j)
)
assert vec_acc.vector_entries == 4
# Mixture of strides in affine and axis
vc = VectorizationContext(
ctx, 4, VectorizationAxis(ctx.get_symbol("ctr"), step=factory.parse_index(3))
)
acc = factory.parse_sympy(mem_acc(ptr, 3 * i + 5 * ctr))
assert isinstance(acc, PsMemAcc)
vec_acc = vectorize.visit(acc, vc)
assert isinstance(vec_acc, PsVecMemAcc)
assert vec_acc.pointer is not acc.pointer
assert vec_acc.pointer.structurally_equal(acc.pointer)
assert vec_acc.offset is not acc.offset
assert vec_acc.offset.structurally_equal(acc.offset)
assert vec_acc.stride.structurally_equal(factory.parse_index(15))
assert vec_acc.vector_entries == 4
def test_invalid_mem_acc():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
vectorize = AstVectorizer(ctx)
ctr = TypedSymbol("ctr", ctx.index_dtype)
axis = VectorizationAxis(
ctx.get_symbol("ctr", ctx.index_dtype),
)
vc = VectorizationContext(ctx, 4, axis)
i = TypedSymbol("i", ctx.index_dtype)
j = TypedSymbol("j", ctx.index_dtype)
ptr = TypedSymbol("ptr", create_type("float64 *"))
# Non-symbol pointer
acc = factory.parse_sympy(
mem_acc(AddressOf(mem_acc(ptr, 10)), 3 * i + ctr * (3 + ctr))
)
with pytest.raises(VectorizationError):
_ = vectorize.visit(acc, vc)
# Non-affine index
acc = factory.parse_sympy(mem_acc(ptr, 3 * i + ctr * (3 + ctr)))
with pytest.raises(VectorizationError):
_ = vectorize.visit(acc, vc)
# Non lane-invariant index
vc.vectorize_symbol(ctx.get_symbol("j", ctx.index_dtype))
acc = factory.parse_sympy(mem_acc(ptr, 3 * j + ctr))
with pytest.raises(VectorizationError):
_ = vectorize.visit(acc, vc)
def test_vectorize_buffer_acc():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
vectorize = AstVectorizer(ctx)
field = fields("f(3): [3D]", layout="fzyx")
ispace = FullIterationSpace.create_with_ghost_layers(ctx, 0, archetype_field=field)
ctx.set_iteration_space(ispace)
ctr = ispace.dimensions_in_loop_order()[-1].counter
axis = VectorizationAxis(ctr)
vc = VectorizationContext(ctx, 4, axis)
buf = ctx.get_buffer(field)
acc = factory.parse_sympy(field[-1, -1, -1](2))
# Buffer strides are symbolic -> expect strided access
vec_acc = vectorize.visit(acc, vc)
assert isinstance(vec_acc, PsVecMemAcc)
assert vec_acc.stride is not None
assert vec_acc.stride.structurally_equal(PsExpression.make(buf.strides[0]))
# Set buffer stride to one
buf.strides[0] = PsConstant(1, dtype=ctx.index_dtype)
# Expect non-strided access
vec_acc = vectorize.visit(acc, vc)
assert isinstance(vec_acc, PsVecMemAcc)
assert vec_acc.stride is None
def test_invalid_buffer_acc():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
vectorize = AstVectorizer(ctx)
field = fields("field(3): [3D]", field_type=FieldType.CUSTOM)
ctr, i, j = [TypedSymbol(n, ctx.index_dtype) for n in ("ctr", "i", "j")]
axis = VectorizationAxis(ctx.get_symbol("ctr", ctx.index_dtype))
vc = VectorizationContext(ctx, 4, axis)
# Counter occurs in more than one index
acc = factory.parse_sympy(field.absolute_access((ctr, i, ctr + j), (1,)))
with pytest.raises(VectorizationError):
_ = vectorize.visit(acc, vc)
# Counter occurs in index dimension
acc = factory.parse_sympy(field.absolute_access((ctr, i, j), (ctr,)))
with pytest.raises(VectorizationError):
_ = vectorize.visit(acc, vc)
# Counter occurs quadratically
acc = factory.parse_sympy(field.absolute_access(((ctr + i) * ctr, i, j), (1,)))
with pytest.raises(VectorizationError):
_ = vectorize.visit(acc, vc)
def test_vectorize_subscript():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
vectorize = AstVectorizer(ctx)
ctr = ctx.get_symbol("ctr", ctx.index_dtype)
axis = VectorizationAxis(ctr)
vc = VectorizationContext(ctx, 4, axis)
acc = PsSubscript(
PsExpression.make(ctx.get_symbol("arr", PsArrayType(ctx.default_dtype, 42))),
[PsExpression.make(ctx.get_symbol("i", ctx.index_dtype))],
) # independent of vectorization axis
vec_acc = vectorize.visit(factory._typify(acc), vc)
assert isinstance(vec_acc, PsVecBroadcast)
assert isinstance(vec_acc.operand, PsSubscript)
def test_invalid_subscript():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
vectorize = AstVectorizer(ctx)
ctr = ctx.get_symbol("ctr", ctx.index_dtype)
axis = VectorizationAxis(ctr)
vc = VectorizationContext(ctx, 4, axis)
acc = PsSubscript(
PsExpression.make(ctx.get_symbol("arr", PsArrayType(ctx.default_dtype, 42))),
[PsExpression.make(ctr)], # depends on vectorization axis
)
with pytest.raises(VectorizationError):
_ = vectorize.visit(factory._typify(acc), vc)
def test_vectorize_nested_loop():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
vectorize = AstVectorizer(ctx)
ctr = ctx.get_symbol("i", ctx.index_dtype)
axis = VectorizationAxis(ctr)
vc = VectorizationContext(ctx, 4, axis)
ast = factory.loop_nest(
("i", "j"),
make_slice[:8, :8], # inner loop does not depend on vectorization axis
PsBlock(
[
PsDeclaration(
PsExpression.make(ctx.get_symbol("x", ctx.default_dtype)),
PsExpression.make(PsConstant(42, ctx.default_dtype)),
)
]
),
)
vec_ast = vectorize.visit(ast, vc)
inner_loop = next(
dfs_preorder(
vec_ast,
lambda node: isinstance(node, PsLoop) and node.counter.symbol.name == "j",
)
)
decl = inner_loop.body.statements[0]
assert inner_loop.step.structurally_equal(
PsExpression.make(PsConstant(1, ctx.index_dtype))
)
assert isinstance(decl.lhs.symbol.dtype, PsVectorType)
def test_invalid_nested_loop():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
vectorize = AstVectorizer(ctx)
ctr = ctx.get_symbol("i", ctx.index_dtype)
axis = VectorizationAxis(ctr)
vc = VectorizationContext(ctx, 4, axis)
ast = factory.loop_nest(
("i", "j"),
make_slice[:8, :ctr], # inner loop depends on vectorization axis
PsBlock(
[
PsDeclaration(
PsExpression.make(ctx.get_symbol("x", ctx.default_dtype)),
PsExpression.make(PsConstant(42, ctx.default_dtype)),
)
]
),
)
with pytest.raises(VectorizationError):
_ = vectorize.visit(ast, vc)
import pytest
from pystencils import make_slice
from pystencils.backend.kernelcreation import (
KernelCreationContext,
Typifier,
AstFactory,
)
from pystencils.backend.ast.expressions import (
PsExpression,
PsEq,
PsGe,
PsGt,
PsLe,
PsLt,
)
from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment
from pystencils.backend.constants import PsConstant
from pystencils.backend.transformations import EliminateBranches
from pystencils.types.quick import Int
i0 = PsExpression.make(PsConstant(0, Int(32)))
i1 = PsExpression.make(PsConstant(1, Int(32)))
def test_eliminate_conditional():
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateBranches(ctx)
b1 = PsBlock([PsComment("Branch One")])
b2 = PsBlock([PsComment("Branch Two")])
cond = typify(PsConditional(PsGt(i1, i0), b1, b2))
result = elim(cond)
assert result == b1
cond = typify(PsConditional(PsGt(-i1, i0), b1, b2))
result = elim(cond)
assert result == b2
cond = typify(PsConditional(PsGt(-i1, i0), b1))
result = elim(cond)
assert result.structurally_equal(PsBlock([]))
def test_eliminate_nested_conditional():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
typify = Typifier(ctx)
elim = EliminateBranches(ctx)
b1 = PsBlock([PsComment("Branch One")])
b2 = PsBlock([PsComment("Branch Two")])
cond = typify(PsConditional(PsGt(i1, i0), b1, b2))
ast = factory.loop_nest(("i", "j"), make_slice[:10, :10], PsBlock([cond]))
result = elim(ast)
assert result.body.statements[0].body.statements[0] == b1
def test_isl():
pytest.importorskip("islpy")
ctx = KernelCreationContext()
factory = AstFactory(ctx)
typify = Typifier(ctx)
elim = EliminateBranches(ctx)
i = PsExpression.make(ctx.get_symbol("i", ctx.index_dtype))
j = PsExpression.make(ctx.get_symbol("j", ctx.index_dtype))
const_2 = PsExpression.make(PsConstant(2, ctx.index_dtype))
const_4 = PsExpression.make(PsConstant(4, ctx.index_dtype))
a_true = PsBlock([PsComment("a true")])
a_false = PsBlock([PsComment("a false")])
b_true = PsBlock([PsComment("b true")])
b_false = PsBlock([PsComment("b false")])
c_true = PsBlock([PsComment("c true")])
c_false = PsBlock([PsComment("c false")])
a = PsConditional(PsLt(i + j, const_2 * const_4), a_true, a_false)
b = PsConditional(PsGe(j, const_4), b_true, b_false)
c = PsConditional(PsEq(i, const_4), c_true, c_false)
outer_loop = factory.loop(j.symbol.name, slice(0, 3), PsBlock([a, b, c]))
outer_cond = typify(
PsConditional(PsLe(i, const_4), PsBlock([outer_loop]), PsBlock([]))
)
ast = outer_cond
result = elim(ast)
assert result.branch_true.statements[0].body.statements[0] == a_true
assert result.branch_true.statements[0].body.statements[1] == b_false
assert result.branch_true.statements[0].body.statements[2] == c
import sympy as sp
from pystencils import Field, Assignment, make_slice, TypedSymbol
from pystencils.types.quick import Arr
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.transformations import CanonicalClone
from pystencils.backend.ast.structural import PsBlock, PsComment
from pystencils.backend.ast.expressions import PsSymbolExpr
from pystencils.backend.ast.iteration import dfs_preorder
def test_clone_entire_ast():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
canon_clone = CanonicalClone(ctx)
f = Field.create_generic("f", 2, index_shape=(5,))
rho = sp.Symbol("rho")
u = sp.symbols("u_:2")
cx = TypedSymbol("cx", Arr(ctx.default_dtype, (5,)))
cy = TypedSymbol("cy", Arr(ctx.default_dtype, (5,)))
cxs = sp.IndexedBase(cx, shape=(5,))
cys = sp.IndexedBase(cy, shape=(5,))
rho_out = Field.create_generic("rho", 2, index_shape=(1,))
u_out = Field.create_generic("u", 2, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f)
ctx.set_iteration_space(ispace)
asms = [
Assignment(cx, (0, 1, -1, 0, 0)),
Assignment(cy, (0, 0, 0, 1, -1)),
Assignment(rho, sum(f.center(i) for i in range(5))),
Assignment(u[0], 1 / rho * sum((f.center(i) * cxs[i]) for i in range(5))),
Assignment(u[1], 1 / rho * sum((f.center(i) * cys[i]) for i in range(5))),
Assignment(rho_out.center(0), rho),
Assignment(u_out.center(0), u[0]),
Assignment(u_out.center(1), u[1]),
]
body = PsBlock(
[PsComment("Compute and export density and velocity")]
+ [factory.parse_sympy(asm) for asm in asms]
)
ast = factory.loops_from_ispace(ispace, body)
ast_clone = canon_clone(ast)
for orig, clone in zip(dfs_preorder(ast), dfs_preorder(ast_clone), strict=True):
assert type(orig) is type(clone)
assert orig is not clone
if isinstance(orig, PsSymbolExpr):
assert isinstance(clone, PsSymbolExpr)
if orig.symbol.name in ("ctr_0", "ctr_1", "rho", "u_0", "u_1", "cx", "cy"):
assert clone.symbol.name == orig.symbol.name + "__0"
# type: ignore
import sympy as sp
from pystencils import Field, Assignment, AddAugmentedAssignment, make_slice, DEFAULTS
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.transformations import CanonicalizeSymbols
from pystencils.backend.ast.structural import PsConditional, PsBlock
def test_deduplication():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
canonicalize = CanonicalizeSymbols(ctx)
f = Field.create_fixed_size("f", (5, 5), memory_strides=(5, 1))
x, y, z = sp.symbols("x, y, z")
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f)
ctx.set_iteration_space(ispace)
ctr_1 = DEFAULTS.spatial_counters[1]
then_branch = PsBlock(
[
factory.parse_sympy(Assignment(x, y)),
factory.parse_sympy(Assignment(f.center(0), x)),
]
)
else_branch = PsBlock(
[
factory.parse_sympy(Assignment(x, z)),
factory.parse_sympy(Assignment(f.center(0), x)),
]
)
ast = PsConditional(
factory.parse_sympy(ctr_1),
then_branch,
else_branch,
)
ast = factory.loops_from_ispace(ispace, PsBlock([ast]))
ast = canonicalize(ast)
assert canonicalize.get_last_live_symbols() == {
ctx.find_symbol("y"),
ctx.find_symbol("z"),
ctx.get_buffer(f).base_pointer,
}
assert ctx.find_symbol("x") is not None
assert ctx.find_symbol("x__0") is not None
assert then_branch.statements[0].declared_symbol.name == "x__0"
assert then_branch.statements[1].rhs.symbol.name == "x__0"
assert else_branch.statements[0].declared_symbol.name == "x"
assert else_branch.statements[1].rhs.symbol.name == "x"
assert ctx.find_symbol("x").dtype.const
assert ctx.find_symbol("x__0").dtype.const
assert ctx.find_symbol("y").dtype.const
assert ctx.find_symbol("z").dtype.const
def test_do_not_constify():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
canonicalize = CanonicalizeSymbols(ctx)
x, z = sp.symbols("x, z")
ast = factory.loop("i", make_slice[:10], PsBlock([
factory.parse_sympy(Assignment(x, z)),
factory.parse_sympy(AddAugmentedAssignment(z, 1))
]))
ast = canonicalize(ast)
assert ctx.find_symbol("x").dtype.const
assert not ctx.find_symbol("z").dtype.const
def test_loop_counters():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
canonicalize = CanonicalizeSymbols(ctx)
f = Field.create_generic("f", 2, index_shape=(1,))
g = Field.create_generic("g", 2, index_shape=(1,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], archetype_field=f)
ctx.set_iteration_space(ispace)
asm = Assignment(f.center(0), 2 * g.center(0))
body = PsBlock([factory.parse_sympy(asm)])
loops = factory.loops_from_ispace(ispace, body)
loops_clone = loops.clone()
loops_clone2 = loops.clone()
ast = PsBlock([loops, loops_clone, loops_clone2])
ast = canonicalize(ast)
assert loops_clone2.counter.symbol.name == "ctr_0"
assert not loops_clone2.counter.symbol.get_dtype().const
assert loops_clone.counter.symbol.name == "ctr_0__0"
assert not loops_clone.counter.symbol.get_dtype().const
assert loops.counter.symbol.name == "ctr_0__1"
assert not loops.counter.symbol.get_dtype().const
from typing import Any
import pytest
import numpy as np
import sympy as sp
from pystencils import TypedSymbol, Assignment
from pystencils.backend.kernelcreation import (
KernelCreationContext,
Typifier,
AstFactory,
)
from pystencils.backend.ast.structural import PsBlock, PsDeclaration
from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr
from pystencils.backend.memory import PsSymbol
from pystencils.backend.constants import PsConstant
from pystencils.backend.transformations import EliminateConstants
from pystencils.backend.ast.expressions import (
PsAnd,
PsOr,
PsNot,
PsEq,
PsGt,
PsTernary,
PsRem,
PsIntDiv,
PsCast
)
from pystencils.types.quick import Int, Fp, Bool
from pystencils.types import PsVectorType, create_numeric_type, constify, create_type
class Exprs:
def __init__(self, mode: str):
self.mode = mode
if mode == "scalar":
self._itype = Int(32)
self._ftype = Fp(32)
self._btype = Bool()
else:
self._itype = PsVectorType(Int(32), 4)
self._ftype = PsVectorType(Fp(32), 4)
self._btype = PsVectorType(Bool(), 4)
self.x, self.y, self.z = [
PsExpression.make(PsSymbol(name, self._ftype)) for name in "xyz"
]
self.p, self.q, self.r = [
PsExpression.make(PsSymbol(name, self._itype)) for name in "pqr"
]
self.a, self.b, self.c = [
PsExpression.make(PsSymbol(name, self._btype)) for name in "abc"
]
self.true = PsExpression.make(PsConstant(True, self._btype))
self.false = PsExpression.make(PsConstant(False, self._btype))
def __call__(self, val) -> PsExpression:
match val:
case int():
return PsExpression.make(PsConstant(val, self._itype))
case float():
return PsExpression.make(PsConstant(val, self._ftype))
case np.ndarray():
return PsExpression.make(
PsConstant(
val, PsVectorType(create_numeric_type(val.dtype), len(val))
)
)
case _:
raise ValueError()
@pytest.fixture(scope="module", params=["scalar", "vector"])
def exprs(request):
return Exprs(request.param)
def test_idempotence(exprs):
e = exprs
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx)
expr = typify(e(42.0) * (e(1.0) + e(0.0)) - e(0.0))
result = elim(expr)
assert isinstance(result, PsConstantExpr) and result.structurally_equal(e(42.0))
expr = typify((e.x + e(0.0)) * e(3.5) + (e(1.0) * e.y + e(0.0)) * e(42.0))
result = elim(expr)
assert result.structurally_equal(e.x * e(3.5) + e.y * e(42.0))
expr = typify((e(3.5) * e(1.0)) + (e(42.0) * e(1.0)))
result = elim(expr)
# do not fold floats by default
assert expr.structurally_equal(e(3.5) + e(42.0))
expr = typify(e(1.0) * e.x + e(0.0) + (e(0.0) + e(0.0) + e(1.0) + e(0.0)) * e.y)
result = elim(expr)
assert result.structurally_equal(e.x + e.y)
expr = typify(e(0.0) - e(3.2))
result = elim(expr)
assert result.structurally_equal(-e(3.2))
def test_int_folding(exprs):
e = exprs
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx)
expr = typify((e(1) * e.p + e(1) * -e(3)) + e(1) * e(12))
result = elim(expr)
assert result.structurally_equal((e.p + e(-3)) + e(12))
expr = typify((e(1) + e(1) + e(1) + e(0) + e(0) + e(1)) * (e(1) + e(1) + e(1)))
result = elim(expr)
assert result.structurally_equal(e(12))
def test_zero_dominance(exprs):
e = exprs
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx)
expr = typify((e(0.0) * e.x) + (e.y * e(0.0)) + e(1.0))
result = elim(expr)
assert result.structurally_equal(e(1.0))
expr = typify((e(3) + e(12) * (e.p + e.q) + e.p / (e(3) * e.q)) * e(0))
result = elim(expr)
assert result.structurally_equal(e(0))
def test_divisions(exprs):
e = exprs
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx)
expr = typify(e(3.5) / e(1.0))
result = elim(expr)
assert result.structurally_equal(e(3.5))
expr = typify(e(3) / e(1))
result = elim(expr)
assert result.structurally_equal(e(3))
expr = typify(PsRem(e(3), e(1)))
result = elim(expr)
assert result.structurally_equal(e(0))
expr = typify(PsIntDiv(e(12), e(3)))
result = elim(expr)
assert result.structurally_equal(e(4))
expr = typify(e(12) / e(3))
result = elim(expr)
assert result.structurally_equal(e(4))
expr = typify(PsIntDiv(e(4), e(3)))
result = elim(expr)
assert result.structurally_equal(e(1))
expr = typify(PsIntDiv(-e(4), e(3)))
result = elim(expr)
assert result.structurally_equal(e(-1))
expr = typify(PsIntDiv(e(4), -e(3)))
result = elim(expr)
assert result.structurally_equal(e(-1))
expr = typify(PsIntDiv(-e(4), -e(3)))
result = elim(expr)
assert result.structurally_equal(e(1))
expr = typify(PsRem(e(4), e(3)))
result = elim(expr)
assert result.structurally_equal(e(1))
expr = typify(PsRem(-e(4), e(3)))
result = elim(expr)
assert result.structurally_equal(e(-1))
expr = typify(PsRem(e(4), -e(3)))
result = elim(expr)
assert result.structurally_equal(e(1))
expr = typify(PsRem(-e(4), -e(3)))
result = elim(expr)
assert result.structurally_equal(e(-1))
def test_fold_floats(exprs):
e = exprs
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx, fold_floats=True)
expr = typify(e(8.0) / e(2.0))
result = elim(expr)
assert result.structurally_equal(e(4.0))
expr = typify(e(3.0) * e(12.0) / e(6.0))
result = elim(expr)
assert result.structurally_equal(e(6.0))
def test_boolean_folding(exprs):
e = exprs
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx)
expr = typify(PsNot(PsAnd(e.false, PsOr(e.true, e.a))))
result = elim(expr)
assert result.structurally_equal(e.true)
expr = typify(PsOr(PsAnd(e.a, e.b), PsNot(e.false)))
result = elim(expr)
assert result.structurally_equal(e.true)
expr = typify(PsAnd(e.c, PsAnd(e.true, PsAnd(e.a, PsOr(e.false, e.b)))))
result = elim(expr)
assert result.structurally_equal(PsAnd(e.c, PsAnd(e.a, e.b)))
expr = typify(PsAnd(e.false, PsAnd(e.c, e.a)))
result = elim(expr)
assert result.structurally_equal(e.false)
expr = typify(PsAnd(PsOr(e.a, e.false), e.false))
result = elim(expr)
assert result.structurally_equal(e.false)
def test_relations_folding(exprs):
e = exprs
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx)
expr = typify(PsGt(e.p * e(0), -e(1)))
result = elim(expr)
assert result.structurally_equal(e.true)
expr = typify(PsEq(e(1) + e(1) + e(1), e(3)))
result = elim(expr)
assert result.structurally_equal(e.true)
expr = typify(PsEq(-e(1), -e(3)))
result = elim(expr)
assert result.structurally_equal(e.false)
expr = typify(PsEq(e.x + e.y, e(1.0) * (e.x + e.y)))
result = elim(expr)
assert result.structurally_equal(e.true)
expr = typify(PsGt(e.x + e.y, e(1.0) * (e.x + e.y)))
result = elim(expr)
assert result.structurally_equal(e.false)
def test_ternary_folding():
e = Exprs("scalar")
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx)
expr = typify(PsTernary(e.true, e.x, e.y))
result = elim(expr)
assert result.structurally_equal(e.x)
expr = typify(PsTernary(e.false, e.x, e.y))
result = elim(expr)
assert result.structurally_equal(e.y)
expr = typify(
PsTernary(PsGt(e(1), e(0)), PsTernary(PsEq(e(1), e(12)), e.x, e.y), e.z)
)
result = elim(expr)
assert result.structurally_equal(e.y)
expr = typify(PsTernary(PsGt(e.x, e.y), e.x + e(0.0), e.y * e(1.0)))
result = elim(expr)
assert result.structurally_equal(PsTernary(PsGt(e.x, e.y), e.x, e.y))
def test_fold_vectors():
e = Exprs("vector")
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx, fold_floats=True)
expr = typify(
e(np.array([1, 3, 2, -4]))
- e(np.array([5, -1, -2, 6])) * e(np.array([1, -1, 1, -1]))
)
result = elim(expr)
assert result.structurally_equal(e(np.array([-4, 2, 4, 2])))
expr = typify(
e(np.array([3.0, 1.0, 2.0, 4.0])) * e(np.array([1.0, -1.0, 1.0, -1.0]))
+ e(np.array([2.0, 3.0, 1.0, 4.0]))
)
result = elim(expr)
assert result.structurally_equal(e(np.array([5.0, 2.0, 3.0, 0.0])))
expr = typify(
PsOr(
PsNot(e(np.array([False, False, True, True]))),
e(np.array([False, True, False, True])),
)
)
result = elim(expr)
assert result.structurally_equal(e(np.array([True, True, False, True])))
def test_fold_casts(exprs):
e = exprs
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx, fold_floats=True)
target_type = create_type("float16")
if e.mode == "vector":
target_type = PsVectorType(target_type, 4)
expr = typify(PsCast(target_type, e(41.2)))
result = elim(expr)
assert isinstance(result, PsConstantExpr)
np.testing.assert_equal(result.constant.value, e(41.2).constant.value.astype("float16"))
def test_extract_constant_subexprs():
ctx = KernelCreationContext(default_dtype=create_numeric_type("float64"))
factory = AstFactory(ctx)
elim = EliminateConstants(ctx, extract_constant_exprs=True)
x, y, z = sp.symbols("x, y, z")
q, w = TypedSymbol("q", "float32"), TypedSymbol("w", "float32")
block = PsBlock(
[
factory.parse_sympy(Assignment(x, sp.Rational(3, 2))),
factory.parse_sympy(Assignment(y, x + sp.Rational(7, 4))),
factory.parse_sympy(Assignment(z, y - sp.Rational(12, 5))),
factory.parse_sympy(Assignment(q, w + sp.Rational(7, 4))),
factory.parse_sympy(Assignment(z, y - sp.Rational(12, 5) + z * sp.sin(41))),
]
)
result = elim(block)
assert len(result.statements) == 9
c_symb = ctx.find_symbol("__c_3_0o2_0")
assert c_symb is None
c_symb = ctx.find_symbol("__c_7_0o4_0")
assert c_symb is not None
assert c_symb.dtype == constify(ctx.default_dtype)
c_symb = ctx.find_symbol("__c_s12_0o5_0")
assert c_symb is not None
assert c_symb.dtype == constify(ctx.default_dtype)
# Make sure symbol was duplicated
c_symb = ctx.find_symbol("__c_7_0o4_0__0")
assert c_symb is not None
assert c_symb.dtype == constify(create_numeric_type("float32"))
c_symb = ctx.find_symbol("__c_sin_41_0_")
assert c_symb is not None
assert c_symb.dtype == constify(ctx.default_dtype)
def test_extract_vector_constants():
ctx = KernelCreationContext(default_dtype=create_numeric_type("float64"))
factory = AstFactory(ctx)
typify = Typifier(ctx)
elim = EliminateConstants(ctx, extract_constant_exprs=True)
vtype = PsVectorType(ctx.default_dtype, 8)
x, y, z = TypedSymbol("x", vtype), TypedSymbol("y", vtype), TypedSymbol("z", vtype)
num = typify.typify_expression(
PsExpression.make(
PsConstant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))
),
vtype,
)[0]
denom = typify.typify_expression(PsExpression.make(PsConstant(3.0)), vtype)[0]
vconstant = num / denom
block = PsBlock(
[
factory.parse_sympy(Assignment(x, y - sp.Rational(3, 2))),
PsDeclaration(
factory.parse_sympy(z),
typify(factory.parse_sympy(y) + num / denom),
),
]
)
result = elim(block)
assert len(result.statements) == 4
assert isinstance(result.statements[1], PsDeclaration)
assert result.statements[1].rhs.structurally_equal(vconstant)
import sympy as sp
from pystencils import (
Field,
TypedSymbol,
Assignment,
AddAugmentedAssignment,
make_slice,
)
from pystencils.types.quick import Arr, Fp, Bool
from pystencils.backend.ast.structural import (
PsBlock,
PsLoop,
PsConditional,
PsDeclaration,
)
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.transformations import (
CanonicalizeSymbols,
HoistLoopInvariantDeclarations,
)
def test_hoist_multiple_loops():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
canonicalize = CanonicalizeSymbols(ctx)
hoist = HoistLoopInvariantDeclarations(ctx)
f = Field.create_fixed_size("f", (5, 5), memory_strides=(5, 1))
x, y, z = sp.symbols("x, y, z")
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], f)
ctx.set_iteration_space(ispace)
first_loop = factory.loops_from_ispace(
ispace,
PsBlock(
[
factory.parse_sympy(Assignment(x, y)),
factory.parse_sympy(Assignment(f.center(0), x)),
]
),
)
second_loop = factory.loops_from_ispace(
ispace,
PsBlock(
[
factory.parse_sympy(Assignment(x, z)),
factory.parse_sympy(Assignment(f.center(0), x)),
]
),
)
ast = PsBlock([first_loop, second_loop])
result = canonicalize(ast)
result = hoist(result)
assert isinstance(result, PsBlock)
assert (
isinstance(result.statements[0], PsDeclaration)
and result.statements[0].declared_symbol.name == "x__0"
)
assert result.statements[1] == first_loop
assert (
isinstance(result.statements[2], PsDeclaration)
and result.statements[2].declared_symbol.name == "x"
)
assert result.statements[3] == second_loop
def test_hoist_with_recurrence():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
hoist = HoistLoopInvariantDeclarations(ctx)
x, y = sp.symbols("x, y")
x_decl = factory.parse_sympy(Assignment(x, 1))
x_update = factory.parse_sympy(AddAugmentedAssignment(x, 1))
y_decl = factory.parse_sympy(Assignment(y, 2 * x))
loop = factory.loop("i", make_slice[0:10:1], PsBlock([y_decl, x_update]))
ast = PsBlock([x_decl, loop])
result = hoist(ast)
# x is updated in the loop, so nothing can be hoisted
assert isinstance(result, PsBlock)
assert result.statements == [x_decl, loop]
assert loop.body.statements == [y_decl, x_update]
def test_hoist_with_conditionals():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
hoist = HoistLoopInvariantDeclarations(ctx)
x, y, z, w = sp.symbols("x, y, z, w")
x_decl = factory.parse_sympy(Assignment(x, 1))
x_update = factory.parse_sympy(AddAugmentedAssignment(x, 1))
y_decl = factory.parse_sympy(Assignment(y, 2 * x))
z_decl = factory.parse_sympy(Assignment(z, 312))
w_decl = factory.parse_sympy(Assignment(w, 142))
cond = factory.parse_sympy(TypedSymbol("cond", Bool()))
inner_conditional = PsConditional(cond, PsBlock([x_update, z_decl]))
loop = factory.loop(
"i",
make_slice[0:10:1],
PsBlock([y_decl, w_decl, inner_conditional]),
)
outer_conditional = PsConditional(cond, PsBlock([loop]))
ast = PsBlock([x_decl, outer_conditional])
result = hoist(ast)
# z is hidden inside conditional, so z cannot be hoisted
# x is updated conditionally, so y cannot be hoisted
assert isinstance(result, PsBlock)
assert result.statements == [x_decl, outer_conditional]
assert outer_conditional.branch_true.statements == [w_decl, loop]
assert loop.body.statements == [y_decl, inner_conditional]
def test_hoist_arrays():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
hoist = HoistLoopInvariantDeclarations(ctx)
const_arr_symb = TypedSymbol(
"const_arr",
Arr(Fp(64), (10,), const=True),
)
const_array_decl = factory.parse_sympy(Assignment(const_arr_symb, tuple(range(10))))
const_arr = sp.IndexedBase(const_arr_symb, shape=(10,))
arr_symb = TypedSymbol(
"arr",
Arr(Fp(64), (10,), const=False),
)
array_decl = factory.parse_sympy(Assignment(arr_symb, tuple(range(10))))
arr = sp.IndexedBase(arr_symb, shape=(10,))
x, y = sp.symbols("x, y")
nonconst_usage = factory.parse_sympy(Assignment(x, arr[3]))
const_usage = factory.parse_sympy(Assignment(y, const_arr[3]))
body = PsBlock([array_decl, const_array_decl, nonconst_usage, const_usage])
loop = factory.loop_nest(("i", "j"), make_slice[:10, :42], body)
result = hoist(loop)
assert isinstance(result, PsBlock)
assert result.statements == [array_decl, const_array_decl, const_usage, loop]
assert isinstance(loop.body.statements[0], PsLoop)
assert loop.body.statements[0].body.statements == [nonconst_usage]
def test_hoisting_eliminates_loops():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
hoist = HoistLoopInvariantDeclarations(ctx)
x, y, z = sp.symbols("x, y, z")
invariant_decls = [
factory.parse_sympy(Assignment(x, 42)),
factory.parse_sympy(Assignment(y, 2 * x)),
factory.parse_sympy(Assignment(z, x + 4 * y)),
]
ast = factory.loop_nest(("i", "j"), make_slice[:10, :42], PsBlock(invariant_decls))
ast = hoist(ast)
assert isinstance(ast, PsBlock)
# All statements are hoisted and the loops are removed
assert ast.statements == invariant_decls
def test_hoist_mutation():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
hoist = HoistLoopInvariantDeclarations(ctx)
x = sp.Symbol("x")
x_decl = factory.parse_sympy(Assignment(x, 1))
x_update = factory.parse_sympy(AddAugmentedAssignment(x, 1))
inner_loop = factory.loop("j", slice(10), PsBlock([x_update]))
outer_loop = factory.loop("i", slice(10), PsBlock([x_decl, inner_loop]))
result = hoist(outer_loop)
# x is updated in the loop, so nothing can be hoisted
assert isinstance(result, PsLoop)
assert result.body.statements == [x_decl, inner_loop]
from functools import reduce
from operator import add
from pystencils import fields, Assignment, make_slice, Field, FieldType
from pystencils.types import PsStructType, create_type
from pystencils.backend.memory import BufferBasePtr
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.transformations import LowerToC
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.expressions import (
PsBufferAcc,
PsMemAcc,
PsSymbolExpr,
PsExpression,
PsLookup,
PsAddressOf,
PsCast,
)
from pystencils.backend.ast.structural import PsAssignment
def test_lower_buffer_accesses():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:42, :31])
ctx.set_iteration_space(ispace)
lower = LowerToC(ctx)
f, g = fields("f(2), g(3): [2D]")
asm = Assignment(f.center(1), g[-1, 1](2))
f_buf = ctx.get_buffer(f)
g_buf = ctx.get_buffer(g)
fasm = factory.parse_sympy(asm)
assert isinstance(fasm.lhs, PsBufferAcc)
assert isinstance(fasm.rhs, PsBufferAcc)
fasm_lowered = lower(fasm)
assert isinstance(fasm_lowered, PsAssignment)
assert isinstance(fasm_lowered.lhs, PsMemAcc)
assert isinstance(fasm_lowered.lhs.pointer, PsSymbolExpr)
assert fasm_lowered.lhs.pointer.symbol == f_buf.base_pointer
expected_offset = reduce(
add,
(
(PsExpression.make(dm.counter)) * PsExpression.make(stride)
for dm, stride in zip(ispace.dimensions, f_buf.strides)
),
) + PsExpression.make(f_buf.strides[-1])
assert fasm_lowered.lhs.offset.structurally_equal(expected_offset)
assert isinstance(fasm_lowered.rhs, PsMemAcc)
assert isinstance(fasm_lowered.rhs.pointer, PsSymbolExpr)
assert fasm_lowered.rhs.pointer.symbol == g_buf.base_pointer
expected_offset = (
(PsExpression.make(ispace.dimensions[0].counter) + factory.parse_index(-1))
* PsExpression.make(g_buf.strides[0])
+ (PsExpression.make(ispace.dimensions[1].counter) + factory.parse_index(1))
* PsExpression.make(g_buf.strides[1])
+ factory.parse_index(2) * PsExpression.make(g_buf.strides[-1])
)
assert fasm_lowered.rhs.offset.structurally_equal(expected_offset)
def test_lower_anonymous_structs():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:12])
ctx.set_iteration_space(ispace)
lower = LowerToC(ctx)
stype = PsStructType(
[
("val", ctx.default_dtype),
("x", ctx.index_dtype),
]
)
sfield = Field.create_generic("s", spatial_dimensions=1, dtype=stype)
f = Field.create_generic("f", 1, ctx.default_dtype, field_type=FieldType.CUSTOM)
asm = Assignment(sfield.center("val"), f.absolute_access((sfield.center("x"),), (0,)))
fasm = factory.parse_sympy(asm)
sbuf = ctx.get_buffer(sfield)
assert isinstance(fasm, PsAssignment)
assert isinstance(fasm.lhs, PsLookup)
lowered_fasm = lower(fasm.clone())
assert isinstance(lowered_fasm, PsAssignment)
# Check type of sfield data pointer
for expr in dfs_preorder(lowered_fasm, lambda n: isinstance(n, PsSymbolExpr)):
if expr.symbol.name == sbuf.base_pointer.name:
assert expr.symbol.dtype == create_type("uint8_t * restrict")
# Check LHS
assert isinstance(lowered_fasm.lhs, PsMemAcc)
assert isinstance(lowered_fasm.lhs.pointer, PsCast)
assert isinstance(lowered_fasm.lhs.pointer.operand, PsAddressOf)
assert isinstance(lowered_fasm.lhs.pointer.operand.operand, PsMemAcc)
type_erased_pointer = lowered_fasm.lhs.pointer.operand.operand.pointer
assert isinstance(type_erased_pointer, PsSymbolExpr)
assert BufferBasePtr(sbuf) in type_erased_pointer.symbol.properties
assert type_erased_pointer.symbol.dtype == create_type("uint8_t * restrict")
import sympy as sp
from pystencils import Field, Assignment, make_slice
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.transformations import ReshapeLoops
from pystencils.backend.ast.structural import (
PsDeclaration,
PsBlock,
PsLoop,
PsConditional,
)
from pystencils.backend.ast.expressions import PsConstantExpr, PsGe, PsLt
def test_loop_cutting():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
reshape = ReshapeLoops(ctx)
x, y, z = sp.symbols("x, y, z")
f = Field.create_generic("f", 1, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
ctx.set_iteration_space(ispace)
loop_body = PsBlock(
[
factory.parse_sympy(Assignment(x, 2 * z)),
factory.parse_sympy(Assignment(f.center(0), x + y)),
]
)
loop = factory.loops_from_ispace(ispace, loop_body)
subloops = reshape.cut_loop(loop, [1, 1, 3])
assert len(subloops) == 3
subloop = subloops[0]
assert isinstance(subloop, PsBlock)
assert isinstance(subloop.statements[0], PsDeclaration)
assert subloop.statements[0].declared_symbol.name == "ctr_0__0"
x_decl = subloop.statements[1]
assert isinstance(x_decl, PsDeclaration)
assert x_decl.declared_symbol.name == "x__0"
subloop = subloops[1]
assert isinstance(subloop, PsLoop)
assert (
isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
)
assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3
x_decl = subloop.body.statements[0]
assert isinstance(x_decl, PsDeclaration)
assert x_decl.declared_symbol.name == "x__1"
subloop = subloops[2]
assert isinstance(subloop, PsLoop)
assert (
isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
)
assert subloop.stop.structurally_equal(loop.stop)
def test_loop_peeling():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
reshape = ReshapeLoops(ctx)
x, y, z = sp.symbols("x, y, z")
f = Field.create_generic("f", 1, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(
ctx, slice(2, 11, 3), archetype_field=f
)
ctx.set_iteration_space(ispace)
loop_body = PsBlock(
[
factory.parse_sympy(Assignment(x, 2 * z)),
factory.parse_sympy(Assignment(f.center(0), x + y)),
]
)
loop = factory.loops_from_ispace(ispace, loop_body)
num_iters = 2
peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
assert len(peeled_iters) == num_iters
for i, iter in enumerate(peeled_iters):
assert isinstance(iter, PsBlock)
ctr_decl = iter.statements[0]
assert isinstance(ctr_decl, PsDeclaration)
assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
ctr_value = {0: 2, 1: 5}[i]
assert ctr_decl.rhs.structurally_equal(factory.parse_index(ctr_value))
cond = iter.statements[1]
assert isinstance(cond, PsConditional)
assert cond.condition.structurally_equal(PsLt(ctr_decl.lhs, loop.stop))
subblock = cond.branch_true
assert isinstance(subblock.statements[0], PsDeclaration)
assert subblock.statements[0].declared_symbol.name == f"x__{i}"
assert peeled_loop.start.structurally_equal(factory.parse_index(8))
assert peeled_loop.stop.structurally_equal(loop.stop)
assert peeled_loop.body.structurally_equal(loop.body)
def test_loop_peeling_back():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
reshape = ReshapeLoops(ctx)
x, y, z = sp.symbols("x, y, z")
f = Field.create_generic("f", 1, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
ctx.set_iteration_space(ispace)
loop_body = PsBlock(
[
factory.parse_sympy(Assignment(x, 2 * z)),
factory.parse_sympy(Assignment(f.center(0), x + y)),
]
)
loop = factory.loops_from_ispace(ispace, loop_body)
num_iters = 3
peeled_loop, peeled_iters = reshape.peel_loop_back(loop, num_iters)
assert len(peeled_iters) == 3
for i, iter in enumerate(peeled_iters):
assert isinstance(iter, PsBlock)
ctr_decl = iter.statements[0]
assert isinstance(ctr_decl, PsDeclaration)
assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
cond = iter.statements[1]
assert isinstance(cond, PsConditional)
assert cond.condition.structurally_equal(PsGe(ctr_decl.lhs, loop.start))
subblock = cond.branch_true
assert isinstance(subblock.statements[0], PsDeclaration)
assert subblock.statements[0].declared_symbol.name == f"x__{i}"
assert peeled_loop.start.structurally_equal(loop.start)
assert peeled_loop.stop.structurally_equal(
factory.loops_from_ispace(ispace, loop_body).stop
- factory.parse_index(num_iters)
)
assert peeled_loop.body.structurally_equal(loop.body)
...@@ -46,6 +46,9 @@ def test_3d_arrays(order, alignment, shape): ...@@ -46,6 +46,9 @@ def test_3d_arrays(order, alignment, shape):
@pytest.mark.parametrize("parallel", [False, True]) @pytest.mark.parametrize("parallel", [False, True])
def test_data_handling(parallel): def test_data_handling(parallel):
if parallel:
pytest.importorskip("waLBerla")
for tries in range(16): # try a few times, since we might get lucky and get randomly a correct alignment for tries in range(16): # try a few times, since we might get lucky and get randomly a correct alignment
dh = create_data_handling((6, 7), default_ghost_layers=1, parallel=parallel) dh = create_data_handling((6, 7), default_ghost_layers=1, parallel=parallel)
dh.add_array('test', alignment=8 * 4, values_per_cell=1) dh.add_array('test', alignment=8 * 4, values_per_cell=1)
......
...@@ -8,7 +8,7 @@ import pystencils ...@@ -8,7 +8,7 @@ import pystencils
from pystencils import Assignment, create_kernel from pystencils import Assignment, create_kernel
from pystencils.boundaries import BoundaryHandling, Dirichlet, Neumann, add_neumann_boundary from pystencils.boundaries import BoundaryHandling, Dirichlet, Neumann, add_neumann_boundary
from pystencils.datahandling import SerialDataHandling from pystencils.datahandling import SerialDataHandling
from pystencils.enums import Target from pystencils import Target
from pystencils.slicing import slice_from_direction from pystencils.slicing import slice_from_direction
from pystencils.timeloop import TimeLoop from pystencils.timeloop import TimeLoop
......
import numpy as np import numpy as np
import pytest import pytest
import pystencils import pystencils as ps
def test_dtype_check_wrong_type(): def test_dtype_check_wrong_type():
array = np.ones((10, 20)).astype(np.float32) array = np.ones((10, 20)).astype(np.float32)
output = np.zeros_like(array) output = np.zeros_like(array)
x, y = pystencils.fields('x,y: [2D]') x, y = ps.fields('x,y: [2D]')
stencil = [[1, 1, 1], stencil = [[1, 1, 1],
[1, 1, 1], [1, 1, 1],
[1, 1, 1]] [1, 1, 1]]
assignment = pystencils.assignment.assignment_from_stencil(stencil, x, y, normalization_factor=1 / np.sum(stencil)) assignment = ps.assignment_from_stencil(stencil, x, y, normalization_factor=1 / np.sum(stencil))
kernel = pystencils.create_kernel([assignment]).compile() kernel = ps.create_kernel([assignment]).compile()
with pytest.raises(ValueError) as e: with pytest.raises(TypeError) as e:
kernel(x=array, y=output) kernel(x=array, y=output)
assert 'Wrong data type' in str(e.value) assert 'Wrong data type' in str(e.value)
...@@ -22,11 +22,11 @@ def test_dtype_check_wrong_type(): ...@@ -22,11 +22,11 @@ def test_dtype_check_wrong_type():
def test_dtype_check_correct_type(): def test_dtype_check_correct_type():
array = np.ones((10, 20)).astype(np.float64) array = np.ones((10, 20)).astype(np.float64)
output = np.zeros_like(array) output = np.zeros_like(array)
x, y = pystencils.fields('x,y: [2D]') x, y = ps.fields('x,y: [2D]')
stencil = [[1, 1, 1], stencil = [[1, 1, 1],
[1, 1, 1], [1, 1, 1],
[1, 1, 1]] [1, 1, 1]]
assignment = pystencils.assignment.assignment_from_stencil(stencil, x, y, normalization_factor=1 / np.sum(stencil)) assignment = ps.assignment_from_stencil(stencil, x, y, normalization_factor=1 / np.sum(stencil))
kernel = pystencils.create_kernel([assignment]).compile() kernel = ps.create_kernel([assignment]).compile()
kernel(x=array, y=output) kernel(x=array, y=output)
assert np.allclose(output[1:-1, 1:-1], np.ones_like(output[1:-1, 1:-1])) assert np.allclose(output[1:-1, 1:-1], np.ones_like(output[1:-1, 1:-1]))
datahandling_save_test*
\ No newline at end of file
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
import pystencils as ps import pystencils as ps
from pystencils import create_data_handling, create_kernel from pystencils import create_data_handling, create_kernel
from pystencils.gpu.gpu_array_handler import GPUArrayHandler from pystencils.gpu.gpu_array_handler import GPUArrayHandler
from pystencils.enums import Target from pystencils import Target
try: try:
import pytest import pytest
...@@ -249,7 +249,7 @@ def test_add_arrays(): ...@@ -249,7 +249,7 @@ def test_add_arrays():
dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=0, default_layout='numpy') dh = create_data_handling(domain_size=domain_shape, default_ghost_layers=0, default_layout='numpy')
x_, y_ = dh.add_arrays(field_description) x_, y_ = dh.add_arrays(field_description)
x, y = ps.fields(field_description + ': [3,4,5]') x, y = ps.fields(field_description + ': float64[3,4,5]')
assert x_ == x assert x_ == x
assert y_ == y assert y_ == y
...@@ -408,4 +408,3 @@ def test_array_handler(device_number): ...@@ -408,4 +408,3 @@ def test_array_handler(device_number):
assert cpu_array2.base is not None assert cpu_array2.base is not None
assert gpu_array2.base is not None assert gpu_array2.base is not None
assert gpu_array2.strides == cpu_array2.strides assert gpu_array2.strides == cpu_array2.strides
...@@ -11,7 +11,7 @@ from pystencils.slicing import slice_from_direction ...@@ -11,7 +11,7 @@ from pystencils.slicing import slice_from_direction
from pystencils.datahandling.parallel_datahandling import ParallelDataHandling from pystencils.datahandling.parallel_datahandling import ParallelDataHandling
from pystencils.datahandling import create_data_handling from pystencils.datahandling import create_data_handling
from tests.test_datahandling import ( from test_datahandling import (
access_and_gather, kernel_execution_jacobi, reduction, synchronization, vtk_output) access_and_gather, kernel_execution_jacobi, reduction, synchronization, vtk_output)
SCRIPT_FOLDER = Path(__file__).parent.absolute() SCRIPT_FOLDER = Path(__file__).parent.absolute()
......