Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 2729 additions and 0 deletions
import numpy as np
import pytest
from itertools import product
from pystencils import (
create_kernel,
Target,
Assignment,
Field,
)
from pystencils.sympyextensions.typed_sympy import tcast
# Every available target whose flags do not include SSE.
AVAIL_TARGETS_NO_SSE = [
    target for target in Target.available_targets() if Target._SSE not in target
]

# Parametrization over (target, source dtype, destination dtype).
# x86 targets without AVX512 are combined with int32/float32/float64 only;
# all remaining targets additionally cover int64.
target_and_dtype = pytest.mark.parametrize(
    "target, from_type, to_type",
    [
        (target, src, dst)
        for target in AVAIL_TARGETS_NO_SSE
        if Target._X86 in target and Target._AVX512 not in target
        for src in (np.int32, np.float32, np.float64)
        for dst in (np.int32, np.float32, np.float64)
    ]
    + [
        (target, src, dst)
        for target in AVAIL_TARGETS_NO_SSE
        if Target._X86 not in target or Target._AVX512 in target
        for src in (np.int32, np.int64, np.float32, np.float64)
        for dst in (np.int32, np.int64, np.float32, np.float64)
    ],
)
@target_and_dtype
def test_type_cast(gen_config, xp, from_type, to_type):
    """Cast an array from ``from_type`` to ``to_type`` inside a generated kernel.

    The kernel output is compared against the array module's own ``astype``
    conversion of the same input values.
    """
    source_is_float = np.issubdtype(from_type, np.floating)
    if source_is_float:
        samples = [-1.25, -0, 1.5, 3, -5, -312, 42, 6.625, -9]
    else:
        samples = [-1, 0, 1, 3, -5, -312, 42, 6, -9]

    inp = xp.array(samples, dtype=from_type)
    outp = xp.zeros_like(inp).astype(to_type)

    # Reference conversions, computed before the kernel is run
    truncated = inp.astype(to_type)
    rounded = xp.round(inp).astype(to_type)

    inp_field = Field.create_from_numpy_array("inp", inp)
    outp_field = Field.create_from_numpy_array("outp", outp)
    cast_asm = Assignment(outp_field.center(), tcast(inp_field.center(), to_type))

    kernel = create_kernel([cast_asm], gen_config)
    kfunc = kernel.compile()
    kfunc(inp=inp, outp=outp)

    float_to_non_float = source_is_float and not np.issubdtype(to_type, np.floating)
    if not float_to_non_float:
        xp.testing.assert_array_equal(outp, truncated)
        return

    # rounding mode depends on platform: accept either truncation or rounding
    try:
        xp.testing.assert_array_equal(outp, truncated)
    except AssertionError:
        xp.testing.assert_array_equal(outp, rounded)
import pytest
from pystencils.field import Field
from pystencils.backend.kernelcreation import (
KernelCreationContext,
FullIterationSpace
)
from pystencils.backend.ast.structural import PsBlock, PsLoop, PsComment
from pystencils.backend.ast.expressions import PsExpression
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.platforms import GenericCpu
@pytest.mark.parametrize("layout", ["fzyx", "zyxf", "c", "f", (2, 0, 1)])
def test_loop_nest(layout):
    """Loops materialized by `GenericCpu` must follow the field's layout order.

    The loops found by a pre-order DFS over the loop nest are compared, one by
    one, with the iteration-space dimensions taken in the order given by the
    archetype field's layout tuple.
    """
    ctx = KernelCreationContext()
    body = PsBlock([PsComment("Loop body goes here")])
    platform = GenericCpu(ctx)

    archetype_field = Field.create_generic("field", spatial_dimensions=3, layout=layout)
    ispace = FullIterationSpace.create_with_ghost_layers(ctx, 0, archetype_field)
    loop_nest = platform.materialize_iteration_space(body, ispace)

    expected_dims = [ispace.dimensions[coord] for coord in archetype_field.layout]
    found_loops = dfs_preorder(loop_nest, lambda node: isinstance(node, PsLoop))

    # strict=True: the number of loops must equal the number of dimensions
    for loop, dim in zip(found_loops, expected_dims, strict=True):
        assert isinstance(loop, PsLoop)
        for bound in ("start", "stop", "step"):
            assert getattr(loop, bound).structurally_equal(getattr(dim, bound))
        assert loop.counter.structurally_equal(PsExpression.make(dim.counter))
import pytest
from pystencils.field import Field
from pystencils.backend.kernelcreation import (
KernelCreationContext,
FullIterationSpace
)
from pystencils.backend.ast.structural import PsBlock, PsComment
from pystencils.backend.platforms import CudaPlatform, SyclPlatform
@pytest.mark.parametrize("layout", ["fzyx", "zyxf", "c", "f"])
@pytest.mark.parametrize("platform_class", [CudaPlatform, SyclPlatform])
def test_thread_range(platform_class, layout):
    """Check the work-item counts of the thread range returned by GPU platforms.

    For each entry of the thread range, the number of work items must be
    structurally equal to ``stop - start`` of the iteration-space dimension
    selected by the layout-dependent indexing order.
    """
    ctx = KernelCreationContext()
    body = PsBlock([PsComment("Kernel body goes here")])
    platform = platform_class(ctx)

    dim = 3
    archetype_field = Field.create_generic(
        "field", spatial_dimensions=dim, layout=layout
    )
    ispace = FullIterationSpace.create_with_ghost_layers(ctx, 1, archetype_field)

    _, threads_range = platform.materialize_iteration_space(body, ispace)
    assert threads_range.dim == dim

    if layout == "c":
        indexing_order = [2, 1, 0]
    elif layout in ("fzyx", "zyxf", "f"):
        indexing_order = [0, 1, 2]

    # Slowest to fastest coordinate
    for i, coordinate in enumerate(indexing_order):
        dimension = ispace.dimensions[coordinate]
        expected_items = dimension.stop - dimension.start
        assert threads_range.num_work_items[i].structurally_equal(expected_items)
from itertools import chain
import pytest
from pystencils import Field, TypedSymbol, FieldType, DynamicType
from pystencils.backend.kernelcreation import KernelCreationContext
from pystencils.backend.constants import PsConstant
from pystencils.backend.memory import PsSymbol
from pystencils.codegen.properties import FieldShape, FieldStride
from pystencils.backend.exceptions import KernelConstraintsError
from pystencils.types.quick import SInt, Fp
from pystencils.types import deconstify
def test_field_arrays():
    """Check the buffers that the kernel creation context builds for fields.

    Covers three fields: a generic scalar field ``f``, a generic field ``g``
    with a ``(2, 4)`` index shape, and a field ``h`` whose shape and stride
    symbols are explicitly typed as 32-bit integers.
    """
    ctx = KernelCreationContext(index_dtype=SInt(16))

    f = Field.create_generic("f", 3, Fp(32))
    f_arr = ctx.get_buffer(f)

    assert f_arr.element_type == f.dtype == Fp(32)
    assert len(f_arr.shape) == len(f.shape) + 1 == 4
    # The trailing (index) dimension of the scalar field is the constant 1
    assert isinstance(f_arr.shape[3], PsConstant) and f_arr.shape[3].value == 1
    assert f_arr.shape[3].dtype == SInt(16, const=True)
    assert f_arr.index_type == ctx.index_dtype == SInt(16)
    assert f_arr.shape[0].dtype == ctx.index_dtype == SInt(16)

    # Check every entry except the trailing constant one (previously only the
    # first entry was inspected because of a `[:1]` slice).
    for i, s in enumerate(f_arr.shape[:-1]):
        assert isinstance(s, PsSymbol)
        assert FieldShape(f, i) in s.properties

    for i, s in enumerate(f_arr.strides[:-1]):
        assert isinstance(s, PsSymbol)
        assert FieldStride(f, i) in s.properties

    g = Field.create_generic("g", 3, index_shape=(2, 4), dtype=Fp(16))
    g_arr = ctx.get_buffer(g)

    assert g_arr.element_type == g.dtype == Fp(16)
    assert len(g_arr.shape) == len(g.spatial_shape) + len(g.index_shape) == 5
    assert isinstance(g_arr.shape[3], PsConstant) and g_arr.shape[3].value == 2
    assert g_arr.shape[3].dtype == SInt(16, const=True)
    assert isinstance(g_arr.shape[4], PsConstant) and g_arr.shape[4].value == 4
    assert g_arr.shape[4].dtype == SInt(16, const=True)
    assert g_arr.index_type == ctx.index_dtype == SInt(16)

    h = Field(
        "h",
        FieldType.GENERIC,
        Fp(32),
        (0, 1),
        (TypedSymbol("nx", SInt(32)), TypedSymbol("ny", SInt(32)), 1),
        (TypedSymbol("sx", SInt(32)), TypedSymbol("sy", SInt(32)), 1),
    )
    h_arr = ctx.get_buffer(h)

    # The explicitly typed symbols determine the buffer's index type
    assert h_arr.index_type == SInt(32)

    for s in chain(h_arr.shape, h_arr.strides):
        assert deconstify(s.get_dtype()) == SInt(32)

    assert [s.name for s in chain(h_arr.shape[:2], h_arr.strides[:2])] == [
        "nx",
        "ny",
        "sx",
        "sy",
    ]
def test_invalid_fields():
    """`get_buffer` must reject fields with unsuitable shape/stride symbol types.

    Each case pairs a shape-symbol type with a stride-symbol type; requesting
    the buffer is expected to raise `KernelConstraintsError` every time.
    """
    ctx = KernelCreationContext(index_dtype=SInt(16))

    symbol_types = [
        (SInt(32), SInt(64)),
        (Fp(32), Fp(32)),
        (DynamicType.NUMERIC_TYPE, DynamicType.NUMERIC_TYPE),
    ]

    for shape_type, stride_type in symbol_types:
        h = Field(
            "h",
            FieldType.GENERIC,
            Fp(32),
            (0,),
            (TypedSymbol("nx", shape_type),),
            (TypedSymbol("sx", stride_type),),
        )

        with pytest.raises(KernelConstraintsError):
            _ = ctx.get_buffer(h)
def test_duplicate_fields():
    """Two fields sharing indexing symbols must share buffer shape and strides.

    ``g`` is derived from ``f`` under a different name; their buffers must use
    equal shape/stride entries carrying the properties of both fields, while
    the base pointers must differ.
    """
    f = Field.create_generic("f", 3)
    g = f.new_field_with_different_name("g")

    # f and g have the same indexing symbols
    assert f.shape == g.shape
    assert f.strides == g.strides

    ctx = KernelCreationContext()

    f_buf = ctx.get_buffer(f)
    g_buf = ctx.get_buffer(g)

    # strict=True: a rank mismatch between the buffers must fail the test
    # instead of being silently truncated by zip.
    for sf, sg in zip(
        chain(f_buf.shape, f_buf.strides),
        chain(g_buf.shape, g_buf.strides),
        strict=True,
    ):
        # Must be the same
        assert sf == sg

    for i, s in enumerate(f_buf.shape[:-1]):
        assert isinstance(s, PsSymbol)
        assert FieldShape(f, i) in s.properties
        assert FieldShape(g, i) in s.properties

    for i, s in enumerate(f_buf.strides[:-1]):
        assert isinstance(s, PsSymbol)
        assert FieldStride(f, i) in s.properties
        assert FieldStride(g, i) in s.properties

    # Base pointers must be different, though!
    assert f_buf.base_pointer != g_buf.base_pointer
import sympy as sp
import pytest
from pystencils import (
Assignment,
fields,
create_type,
create_numeric_type,
TypedSymbol,
DynamicType,
)
from pystencils.sympyextensions import tcast
from pystencils.sympyextensions.pointers import mem_acc
from pystencils.backend.ast.structural import (
PsAssignment,
PsDeclaration,
)
from pystencils.backend.ast.expressions import (
PsBufferAcc,
PsBitwiseAnd,
PsBitwiseOr,
PsBitwiseXor,
PsExpression,
PsTernary,
PsIntDiv,
PsLeftShift,
PsRightShift,
PsAnd,
PsOr,
PsNot,
PsEq,
PsNe,
PsLt,
PsLe,
PsGt,
PsGe,
PsCall,
PsCast,
PsConstantExpr,
PsAdd,
PsMul,
PsSub,
PsArrayInitList,
PsSubscript,
PsMemAcc,
)
from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import PsMathFunction, MathFunctions
from pystencils.backend.kernelcreation import (
KernelCreationContext,
FreezeExpressions,
FullIterationSpace,
)
from pystencils.backend.kernelcreation.freeze import FreezeError
from pystencils.sympyextensions.integer_functions import (
bit_shift_left,
bit_shift_right,
bitwise_and,
bitwise_or,
bitwise_xor,
int_div,
int_power_of_2,
round_to_multiple_towards_zero,
ceil_to_multiple,
div_ceil,
)
def test_freeze_simple():
    """Freezing ``z = 2*x + y`` must yield the declaration ``z = y + 2*x``.

    Also asserts that the result is not structurally equal to a plain
    assignment with the operands in the other order.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    x, y, z = sp.symbols("x, y, z")
    frozen = freeze(Assignment(z, 2 * x + y))

    x2, y2, z2 = (PsExpression.make(ctx.get_symbol(name)) for name in "xyz")
    two = PsExpression.make(PsConstant(2))

    expected = PsDeclaration(z2, y2 + two * x2)
    assert frozen.structurally_equal(expected)

    unexpected = PsAssignment(z2, two * x2 + y2)
    assert not frozen.structurally_equal(unexpected)
def test_freeze_fields():
    """Field accesses must freeze to buffer accesses indexed by the loop counter.

    A one-dimensional iteration space with counter ``ctr`` is installed in the
    context before freezing ``f.center(0) = g.center(0)``.
    """
    ctx = KernelCreationContext()

    idx_type = ctx.index_dtype
    start = PsExpression.make(PsConstant(0, idx_type))
    stop = PsExpression.make(PsConstant(42, idx_type))
    step = PsExpression.make(PsConstant(1, idx_type))
    counter = ctx.get_symbol("ctr", idx_type)

    ispace = FullIterationSpace(
        ctx, [FullIterationSpace.Dimension(start, stop, step, counter)]
    )
    ctx.set_iteration_space(ispace)

    freeze = FreezeExpressions(ctx)

    f, g = fields("f, g : [1D]")
    asm = Assignment(f.center(0), g.center(0))

    f_arr = ctx.get_buffer(f)
    g_arr = ctx.get_buffer(g)

    frozen = freeze(asm)

    # The expected index expressions use a zero constant without explicit type
    untyped_zero = PsExpression.make(PsConstant(0))

    lhs = PsBufferAcc(
        f_arr.base_pointer, (PsExpression.make(counter) + untyped_zero, untyped_zero)
    )
    rhs = PsBufferAcc(
        g_arr.base_pointer, (PsExpression.make(counter) + untyped_zero, untyped_zero)
    )

    assert frozen.structurally_equal(PsAssignment(lhs, rhs))
def test_freeze_integer_binops():
    """Nested bit-shift and bitwise functions must freeze to the matching AST nodes."""
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    x, y, z = sp.symbols("x, y, z")

    shifted_right = bit_shift_right(bitwise_and(x, y), bitwise_or(y, z))
    frozen = freeze(bit_shift_left(shifted_right, bitwise_xor(x, z)))

    x2, y2, z2 = (PsExpression.make(ctx.get_symbol(name)) for name in "xyz")

    expected = PsLeftShift(
        PsRightShift(PsBitwiseAnd(x2, y2), PsBitwiseOr(y2, z2)),
        PsBitwiseXor(x2, z2),
    )

    assert frozen.structurally_equal(expected)
def test_freeze_integer_functions():
    """Integer helper functions must freeze to their expected integer expressions.

    The symbols ``x``, ``y`` and ``z`` are registered with the context's index
    type before freezing. Each assignment in ``asms`` is compared with the
    entry at the same position in ``should``.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    x2 = PsExpression.make(ctx.get_symbol("x", ctx.index_dtype))
    y2 = PsExpression.make(ctx.get_symbol("y", ctx.index_dtype))
    z2 = PsExpression.make(ctx.get_symbol("z", ctx.index_dtype))

    x, y, z = sp.symbols("x, y, z")
    one = PsExpression.make(PsConstant(1))
    asms = [
        Assignment(z, int_div(x, y)),
        Assignment(z, int_power_of_2(x, y)),
        Assignment(z, round_to_multiple_towards_zero(x, y)),
        Assignment(z, ceil_to_multiple(x, y)),
        Assignment(z, div_ceil(x, y)),
    ]

    fasms = [freeze(asm) for asm in asms]

    should = [
        PsDeclaration(z2, PsIntDiv(x2, y2)),
        PsDeclaration(z2, PsLeftShift(PsExpression.make(PsConstant(1)), x2)),
        PsDeclaration(z2, PsIntDiv(x2, y2) * y2),
        PsDeclaration(z2, PsIntDiv(x2 + y2 - one, y2) * y2),
        PsDeclaration(z2, PsIntDiv(x2 + y2 - one, y2)),
    ]

    # strict=True: adding a case to only one of the lists must fail loudly
    # rather than being silently dropped by zip.
    for fasm, correct in zip(fasms, should, strict=True):
        assert fasm.structurally_equal(correct)
def test_freeze_booleans():
    """Boolean sympy expressions must freeze to `PsNot`/`PsAnd`/`PsOr` trees.

    Multi-argument ``And``/``Or`` are expected as left-nested binary nodes.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    x2, y2, z2, w2 = (PsExpression.make(ctx.get_symbol(name)) for name in "xyzw")
    x, y, z, w = sp.symbols("x, y, z, w")

    cases = [
        (sp.Not(sp.And(x, y)), PsNot(PsAnd(x2, y2))),
        (
            sp.Or(sp.Not(z), sp.And(y, sp.Not(x))),
            PsOr(PsNot(z2), PsAnd(y2, PsNot(x2))),
        ),
        (sp.And(w, x, y, z), PsAnd(PsAnd(PsAnd(w2, x2), y2), z2)),
        (sp.Or(w, x, y, z), PsOr(PsOr(PsOr(w2, x2), y2), z2)),
    ]

    for sympy_expr, expected in cases:
        frozen = freeze(sympy_expr)
        assert frozen.structurally_equal(expected)
@pytest.mark.parametrize(
    "rel_pair",
    [
        (sp.Eq, PsEq),
        (sp.Ne, PsNe),
        (sp.Lt, PsLt),
        (sp.Gt, PsGt),
        (sp.Le, PsLe),
        (sp.Ge, PsGe),
    ],
)
def test_freeze_relations(rel_pair):
    """Each sympy relation must freeze to its paired backend relation node."""
    sympy_rel, backend_rel = rel_pair

    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    x2, y2, z2 = (PsExpression.make(ctx.get_symbol(name)) for name in "xyz")
    x, y, z = sp.symbols("x, y, z")

    frozen = freeze(sympy_rel(x, y + z))
    expected = backend_rel(x2, y2 + z2)
    assert frozen.structurally_equal(expected)
def test_freeze_piecewise():
    """A `Piecewise` with a final ``True`` branch must freeze to nested ternaries.

    A `Piecewise` whose last condition is not ``True`` is expected to raise
    `FreezeError`.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    p, q, x, y, z = sp.symbols("p, q, x, y, z")
    p2, q2, x2, y2, z2 = (
        PsExpression.make(ctx.get_symbol(name)) for name in "pqxyz"
    )

    with_default = sp.Piecewise((x, p), (y, q), (z, True))
    frozen = freeze(with_default)

    assert isinstance(frozen, PsTernary)
    expected = PsTernary(p2, x2, PsTernary(q2, y2, z2))
    assert frozen.structurally_equal(expected)

    without_default = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q)))
    with pytest.raises(FreezeError):
        freeze(without_default)
def test_multiarg_min_max():
    """`Min`/`Max` with three or four arguments must freeze to nested binary calls.

    Three arguments are expected as ``op(op(w, x), y)``; four arguments as
    ``op(op(w, x), op(y, z))``.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    w, x, y, z = sp.symbols("w, x, y, z")

    x2, y2, z2, w2 = (PsExpression.make(ctx.get_symbol(name)) for name in "xyzw")

    def binary_call(func, a, b):
        # Two-argument call of the given backend math function
        return PsCall(PsMathFunction(func), (a, b))

    for sympy_func, math_func in (
        (sp.Min, MathFunctions.Min),
        (sp.Max, MathFunctions.Max),
    ):
        frozen = freeze(sympy_func(w, x, y))
        expected = binary_call(math_func, binary_call(math_func, w2, x2), y2)
        assert frozen.structurally_equal(expected)

        frozen = freeze(sympy_func(w, x, y, z))
        expected = binary_call(
            math_func,
            binary_call(math_func, w2, x2),
            binary_call(math_func, y2, z2),
        )
        assert frozen.structurally_equal(expected)
def test_dynamic_types():
    """Symbols with dynamic types must resolve to the context's configured types.

    After freezing, `DynamicType.NUMERIC_TYPE` symbols must carry the context's
    default dtype and `DynamicType.INDEX_TYPE` symbols its index dtype.
    """
    ctx = KernelCreationContext(
        default_dtype=create_numeric_type("float16"), index_dtype=create_type("int16")
    )
    freeze = FreezeExpressions(ctx)

    x, y = [TypedSymbol(n, DynamicType.NUMERIC_TYPE) for n in "xy"]
    p, q = [TypedSymbol(n, DynamicType.INDEX_TYPE) for n in "pq"]

    # Only the symbols registered in the context by freezing are inspected;
    # the frozen expressions themselves are not needed.
    freeze(x + y)

    assert ctx.get_symbol("x").dtype == ctx.default_dtype
    assert ctx.get_symbol("y").dtype == ctx.default_dtype

    freeze(p - q)
    assert ctx.get_symbol("p").dtype == ctx.index_dtype
    assert ctx.get_symbol("q").dtype == ctx.index_dtype
def test_cast_func():
    """`tcast` variants must freeze to `PsCast` nodes or typed constants.

    Covers an explicit target type, ``as_numeric`` (context default dtype),
    ``as_index`` (context index dtype), and a cast of the literal 42, which is
    expected as a typed constant expression.
    """
    ctx = KernelCreationContext(
        default_dtype=create_numeric_type("float16"), index_dtype=create_type("int16")
    )
    freeze = FreezeExpressions(ctx)

    x, y, z = sp.symbols("x, y, z")
    x2, y2, z2 = (PsExpression.make(ctx.get_symbol(name)) for name in "xyz")

    explicit = freeze(tcast(x, create_type("int")))
    assert explicit.structurally_equal(PsCast(create_type("int"), x2))

    numeric = freeze(tcast.as_numeric(y))
    assert numeric.structurally_equal(PsCast(ctx.default_dtype, y2))

    index = freeze(tcast.as_index(z))
    assert index.structurally_equal(PsCast(ctx.index_dtype, z2))

    literal = freeze(tcast(42, create_type("int16")))
    typed_constant = PsConstantExpr(PsConstant(42, create_type("int16")))
    assert literal.structurally_equal(typed_constant)
def test_add_sub():
    """Sums and differences must freeze to the expected `PsAdd`/`PsSub` nodes.

    ``y`` is declared as a negative sympy symbol.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    x = sp.Symbol("x")
    y = sp.Symbol("y", negative=True)

    x2 = PsExpression.make(ctx.get_symbol("x"))
    y2 = PsExpression.make(ctx.get_symbol("y"))

    two = PsExpression.make(PsConstant(2))
    minus_two = PsExpression.make(PsConstant(-2))

    cases = [
        (x + y, PsAdd(x2, y2)),
        (x - y, PsSub(x2, y2)),
        (x + 2 * y, PsAdd(x2, PsMul(two, y2))),
        (x - 2 * y, PsAdd(x2, PsMul(minus_two, y2))),
    ]

    for sympy_expr, expected in cases:
        frozen = freeze(sympy_expr)
        assert frozen.structurally_equal(expected)
def test_powers():
    """Powers must freeze to products, reciprocals, sqrt calls or `pow` calls.

    Integer and half-integer exponents are expected as multiplication trees
    (of ``x`` or ``sqrt(x)``), negative ones as ``1 / ...``; a cube root and a
    symbolic exponent are expected as calls to the `Pow` math function.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    x, y, z = sp.symbols("x, y, z")
    x2 = PsExpression.make(ctx.get_symbol("x"))
    y2 = PsExpression.make(ctx.get_symbol("y"))

    def same(actual, expected):
        # Arguments are evaluated left to right, so freezing happens first
        assert actual.structurally_equal(expected)

    # Integer powers
    same(freeze(x**2), x2 * x2)
    same(freeze(x**3), x2 * x2 * x2)
    same(freeze(x**4), (x2 * x2) * (x2 * x2))
    same(freeze(x**5), (x2 * x2) * (x2 * x2) * x2)

    # Negative integer powers
    one = PsExpression.make(PsConstant(1))

    same(freeze(x**-2), one / (x2 * x2))
    same(freeze(x**-3), one / (x2 * x2 * x2))
    same(freeze(x**-4), one / ((x2 * x2) * (x2 * x2)))
    same(freeze(x**-5), one / ((x2 * x2) * (x2 * x2) * x2))

    # Integer powers of the square root
    sqrt_fn = PsMathFunction(MathFunctions.Sqrt)

    same(freeze(x ** sp.Rational(1, 2)), sqrt_fn(x2))
    same(freeze(x ** sp.Rational(2, 2)), x2)
    same(freeze(x ** sp.Rational(3, 2)), sqrt_fn(x2) * sqrt_fn(x2) * sqrt_fn(x2))
    same(freeze(x ** sp.Rational(4, 2)), x2 * x2)
    same(
        freeze(x ** sp.Rational(5, 2)),
        (sqrt_fn(x2) * sqrt_fn(x2)) * (sqrt_fn(x2) * sqrt_fn(x2)) * sqrt_fn(x2),
    )

    # Negative integer powers of sqrt
    same(freeze(x ** sp.Rational(-1, 2)), one / sqrt_fn(x2))
    same(
        freeze(x ** sp.Rational(-3, 2)),
        one / (sqrt_fn(x2) * sqrt_fn(x2) * sqrt_fn(x2)),
    )

    # Cube root
    pow_fn = PsMathFunction(MathFunctions.Pow)
    same(freeze(x ** sp.Rational(1, 3)), pow_fn(x2, freeze(sp.Rational(1, 3))))

    # Unknown exponent
    same(freeze(x**y), pow_fn(x2, y2))
def test_tuple_array_literals():
    """A flat sympy `Tuple` must freeze to a `PsArrayInitList` of its frozen items."""
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    x, y, z = sp.symbols("x, y, z")
    x2, y2, z2 = (PsExpression.make(ctx.get_symbol(name)) for name in "xyz")

    one, three, four = (
        PsExpression.make(PsConstant(value)) for value in (1, 3, 4)
    )

    frozen = freeze(sp.Tuple(3 + y, z, z / 4))
    expected = PsArrayInitList([three + y2, z2, one / four * z2])

    assert frozen.structurally_equal(expected)
def test_nested_tuples():
    """A uniformly nested `Tuple` must freeze to a multi-dimensional init list.

    The frozen literal must report shape ``(2, 3, 2)`` and equal an init list
    built from the individually frozen integer entries.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    def frozen_number(n):
        return freeze(sp.sympify(n))

    values = (
        ((1, 2), (3, 4), (5, 6)),
        ((5, 6), (7, 8), (9, 10)),
    )
    expected_shape = (2, 3, 2)

    arr_literal = freeze(sp.Tuple(*values))

    assert isinstance(arr_literal, PsArrayInitList)
    assert arr_literal.shape == expected_shape

    expected_items = [
        tuple(tuple(frozen_number(entry) for entry in row) for row in plane)
        for plane in values
    ]
    assert arr_literal.structurally_equal(PsArrayInitList(expected_items))
def test_invalid_arrays():
    """Malformed nested tuples must be rejected by freezing with `FreezeError`."""
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    invalid_arrays = [
        # invalid: nonuniform nesting depth
        sp.Tuple((3, 32), 14),
        # invalid: nonuniform sub-array length
        sp.Tuple((3, 32), (14, -7, 3)),
        # invalid: empty subarray
        sp.Tuple((), (0, -9)),
        # invalid: all subarrays empty
        sp.Tuple((), ()),
    ]

    for symb_arr in invalid_arrays:
        with pytest.raises(FreezeError):
            _ = freeze(symb_arr)
def test_memory_access():
    """`mem_acc(ptr, 31)` must freeze to a `PsMemAcc` with matching pointer and offset."""
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    frozen = freeze(mem_acc(sp.Symbol("ptr"), 31))

    assert isinstance(frozen, PsMemAcc)

    expected_pointer = PsExpression.make(ctx.get_symbol("ptr"))
    assert frozen.pointer.structurally_equal(expected_pointer)

    expected_offset = PsExpression.make(PsConstant(31))
    assert frozen.offset.structurally_equal(expected_offset)
def test_indexed():
    """An indexed access ``a[x, y, z]`` must freeze to a `PsSubscript` node."""
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)

    x, y, z = sp.symbols("x, y, z")
    base = sp.IndexedBase("a")

    # Registered in the order x, y, z, a, before anything is frozen
    x2, y2, z2, a2 = (PsExpression.make(ctx.get_symbol(name)) for name in "xyza")

    frozen = freeze(base[x, y, z])
    expected = PsSubscript(a2, (x2, y2, z2))
    assert frozen.structurally_equal(expected)
import pytest
import numpy as np
from pystencils import make_slice, Field, create_type
from pystencils.sympyextensions.typed_sympy import TypedSymbol
from pystencils.backend.constants import PsConstant
from pystencils.backend.kernelcreation import (
KernelCreationContext,
FullIterationSpace,
AstFactory,
)
from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression
from pystencils.backend.kernelcreation.typification import TypificationError
def test_slices_over_field():
    """Iteration space built from slices over a generic 3D field.

    Starts and steps must be constants equal to the slice values. For the two
    slices with a negative stop, the stop must be an addition involving the
    buffer's shape entry; for the open-ended slice it must be the shape entry.
    """
    ctx = KernelCreationContext()

    archetype_field = Field.create_generic("f", spatial_dimensions=3, layout="fzyx")
    ctx.add_field(archetype_field)

    islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, 1))
    ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field)

    archetype_arr = ctx.get_buffer(archetype_field)

    dims = ispace.dimensions

    for sl, dim in zip(islice, dims):
        start, step = dim.start, dim.step
        assert isinstance(start, PsConstantExpr) and start.constant.value == sl.start
        assert isinstance(step, PsConstantExpr) and step.constant.value == sl.step

    for coord in (0, 1):
        stop = dims[coord].stop
        size = PsExpression.make(archetype_arr.shape[coord])
        assert isinstance(stop, PsAdd) and any(
            operand.structurally_equal(size) for operand in stop.children
        )

    full_size = PsExpression.make(archetype_arr.shape[2])
    assert dims[2].stop.structurally_equal(full_size)
def test_slices_with_fixed_size_field():
    """Iteration space built from slices over a fixed-size ``(4, 5, 6)`` field.

    All of start, stop and step must be constants; a negative or missing slice
    stop is resolved against the constant buffer size.
    """
    ctx = KernelCreationContext()

    archetype_field = Field.create_fixed_size("f", (4, 5, 6), layout="fzyx")
    ctx.add_field(archetype_field)

    islice = (slice(1, -1, 1), slice(3, -3, 3), slice(0, None, 1))
    ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field)

    archetype_arr = ctx.get_buffer(archetype_field)

    for sl, size, dim in zip(islice, archetype_arr.shape, ispace.dimensions):
        start = dim.start
        assert isinstance(start, PsConstantExpr) and start.constant.value == sl.start

        assert isinstance(size, PsConstant)

        stop = dim.stop
        if isinstance(stop, PsConstantExpr):
            # A missing stop means "up to the full size"
            if sl.stop is None:
                expected_stop = size.value
            else:
                expected_stop = size.value + sl.stop
            stop_matches = stop.constant.value == np.int64(expected_stop)
        else:
            stop_matches = False
        assert stop_matches

        step = dim.step
        assert isinstance(step, PsConstantExpr) and step.constant.value == sl.step
def test_singular_slice_over_field():
    """Integer entries ``(4, -3)`` in a slice must become one-element ranges.

    The positive index yields ``[4, 5)``; the negative one is expressed
    relative to the buffer's shape entry as ``[shape - 3, shape - 2)``.
    """
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)

    archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx")
    ctx.add_field(archetype_field)
    archetype_arr = ctx.get_buffer(archetype_field)

    ispace = FullIterationSpace.create_from_slice(ctx, (4, -3), archetype_field)
    first, second = ispace.dimensions[0], ispace.dimensions[1]

    assert first.start.structurally_equal(factory.parse_index(4))
    assert first.stop.structurally_equal(factory.parse_index(5))

    expected_start = PsExpression.make(archetype_arr.shape[1]) + factory.parse_index(-3)
    assert second.start.structurally_equal(expected_start)

    expected_stop = PsExpression.make(archetype_arr.shape[1]) + factory.parse_index(-2)
    assert second.stop.structurally_equal(expected_stop)
def test_slices_with_negative_start():
    """Negative slice starts must be expressed as ``shape + start``."""
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)

    archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx")
    ctx.add_field(archetype_field)
    archetype_arr = ctx.get_buffer(archetype_field)

    islice = (slice(-3, -1, 1), slice(-4, None, 1))
    ispace = FullIterationSpace.create_from_slice(ctx, islice, archetype_field)
    dims = ispace.dimensions

    for coord, start in ((0, -3), (1, -4)):
        size = PsExpression.make(archetype_arr.shape[coord])
        expected = size + factory.parse_index(start)
        assert dims[coord].start.structurally_equal(expected)
def test_negative_singular_slices():
    """Negative integer indices ``(-2, -1)`` must become shape-relative ranges.

    Index ``-2`` yields ``[shape - 2, shape - 1)``; index ``-1`` yields
    ``[shape - 1, shape)``, with the stop being the bare shape entry.
    """
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)

    archetype_field = Field.create_generic("f", spatial_dimensions=2, layout="fzyx")
    ctx.add_field(archetype_field)
    archetype_arr = ctx.get_buffer(archetype_field)

    ispace = FullIterationSpace.create_from_slice(ctx, (-2, -1), archetype_field)
    first, second = ispace.dimensions[0], ispace.dimensions[1]

    expected = PsExpression.make(archetype_arr.shape[0]) + factory.parse_index(-2)
    assert first.start.structurally_equal(expected)

    expected = PsExpression.make(archetype_arr.shape[0]) + factory.parse_index(-1)
    assert first.stop.structurally_equal(expected)

    expected = PsExpression.make(archetype_arr.shape[1]) + factory.parse_index(-1)
    assert second.start.structurally_equal(expected)

    expected = PsExpression.make(archetype_arr.shape[1])
    assert second.stop.structurally_equal(expected)
def test_field_independent_slices():
    """Without an archetype field, slice bounds must be taken over as constants."""
    ctx = KernelCreationContext()

    islice = (slice(-3, -1, 1), slice(-4, 7, 2))
    ispace = FullIterationSpace.create_from_slice(ctx, islice)

    for sl, dim in zip(islice, ispace.dimensions):
        # slice objects and dimensions expose the same three attribute names
        for attr in ("start", "stop", "step"):
            bound = getattr(dim, attr)
            assert isinstance(bound, PsConstantExpr)
            assert bound.constant.value == np.int64(getattr(sl, attr))
def test_invalid_slices():
    """Slices with an invalid step must be rejected with the expected exception.

    Covers a float step, a step symbol of type ``double``, a zero step and a
    negative step.
    """
    ctx = KernelCreationContext()

    archetype_field = Field.create_generic("f", spatial_dimensions=1, layout="fzyx")
    ctx.add_field(archetype_field)

    cases = [
        (slice(1, -1, 0.5), TypeError),
        (
            slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),
            TypificationError,
        ),
        (slice(1, 3, 0), ValueError),
        (slice(1, 3, -1), ValueError),
    ]

    for bad_slice, expected_error in cases:
        with pytest.raises(expected_error):
            FullIterationSpace.create_from_slice(ctx, (bad_slice,), archetype_field)
def test_iteration_count():
    """`actual_iterations` must return the expected iteration-count expressions.

    Checks a symbolic range ``3 : i - 2``, the constant range ``1:8:3``
    (three iterations) and two empty ranges (zero iterations).
    """
    ctx = KernelCreationContext()

    idx_type = ctx.index_dtype
    i, j, k = [PsExpression.make(ctx.get_symbol(name, idx_type)) for name in "ijk"]
    zero, two, three = (
        PsExpression.make(PsConstant(value, idx_type)) for value in (0, 2, 3)
    )

    ispace = FullIterationSpace.create_from_slice(
        ctx, make_slice[three : i - two, 1:8:3]
    )

    first_count, second_count = (ispace.actual_iterations(c) for c in range(2))
    assert first_count.structurally_equal((i - two) - three)
    assert second_count.structurally_equal(three)

    empty_ispace = FullIterationSpace.create_from_slice(ctx, make_slice[4:4:1, 4:4:7])

    first_count, second_count = (empty_ispace.actual_iterations(c) for c in range(2))
    assert first_count.structurally_equal(zero)
    assert second_count.structurally_equal(zero)
import pytest
from pystencils import (
fields,
Assignment,
create_kernel,
CreateKernelConfig,
Target,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsLoop, PsPragma
@pytest.mark.parametrize("nesting_depth", range(3))
@pytest.mark.parametrize("schedule", ["static", "static,16", "dynamic", "auto"])
@pytest.mark.parametrize("collapse", [None, 1, 2])
@pytest.mark.parametrize("omit_parallel_construct", range(3))
def test_openmp(nesting_depth, schedule, collapse, omit_parallel_construct):
f, g = fields("f, g: [3D]")
asm = Assignment(f.center(0), g.center(0))
gen_config = CreateKernelConfig(target=Target.CPU)
gen_config.cpu.openmp.enable = True
gen_config.cpu.openmp.nesting_depth = nesting_depth
gen_config.cpu.openmp.schedule = schedule
gen_config.cpu.openmp.collapse = collapse
gen_config.cpu.openmp.omit_parallel_construct = omit_parallel_construct
kernel = create_kernel(asm, gen_config)
ast = kernel.body
def find_omp_pragma(ast) -> PsPragma:
num_loops = 0
generator = dfs_preorder(ast)
for node in generator:
match node:
case PsLoop():
num_loops += 1
case PsPragma():
loop = next(generator)
assert isinstance(loop, PsLoop)
assert num_loops == nesting_depth
return node
pytest.fail("No OpenMP pragma found")
pragma = find_omp_pragma(ast)
tokens = set(pragma.text.split())
expected_tokens = {"omp", "for", f"schedule({schedule})"}
if not omit_parallel_construct:
expected_tokens.add("parallel")
if collapse is not None:
expected_tokens.add(f"collapse({collapse})")
assert tokens == expected_tokens
import pytest
import sympy as sp
import numpy as np
from typing import cast
from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment
from pystencils.sympyextensions.pointers import mem_acc
from pystencils.backend.ast.structural import (
PsDeclaration,
PsAssignment,
PsExpression,
PsConditional,
PsBlock,
)
from pystencils.backend.ast.expressions import (
PsArrayInitList,
PsCast,
PsConstantExpr,
PsSymbolExpr,
PsSubscript,
PsBinOp,
PsAnd,
PsOr,
PsNot,
PsEq,
PsNe,
PsGe,
PsLe,
PsGt,
PsLt,
PsCall,
PsTernary,
PsMemAcc
)
from pystencils.backend.ast.vector import PsVecBroadcast
from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import CFunction
from pystencils.types import constify, create_type, create_numeric_type, PsVectorType
from pystencils.types.quick import Fp, Int, Bool, Arr, Ptr
from pystencils.backend.kernelcreation.context import KernelCreationContext
from pystencils.backend.kernelcreation.freeze import FreezeExpressions
from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
from pystencils.sympyextensions.integer_functions import (
bit_shift_left,
bit_shift_right,
bitwise_and,
bitwise_xor,
bitwise_or,
)
def test_typify_simple():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y, z = sp.symbols("x, y, z")
asm = Assignment(z, 2 * x + y)
fasm = freeze(asm)
fasm = typify(fasm)
assert isinstance(fasm, PsDeclaration)
def check(expr):
assert expr.dtype == ctx.default_dtype
match expr:
case PsConstantExpr(cs):
assert cs.value == 2
assert cs.dtype == constify(ctx.default_dtype)
case PsSymbolExpr(symb):
assert symb.name in "xyz"
assert symb.dtype == ctx.default_dtype
case PsBinOp(op1, op2):
check(op1)
check(op2)
case _:
pytest.fail(f"Unexpected expression: {expr}")
check(fasm.lhs)
check(fasm.rhs)
def test_lhs_constness():
    """Assignments to const-typed left-hand sides must fail typification.

    Covers a const-typed field, a member of a const struct field, and a const
    symbol, which is accepted in a declaration but rejected in an assignment.
    """
    elem_type = Fp(32)
    ctx = KernelCreationContext(default_dtype=elem_type)
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    f = Field.create_generic(
        "f", 1, index_shape=(1,), dtype=elem_type, field_type=FieldType.CUSTOM
    )
    f_const = Field.create_generic(
        "f_const",
        1,
        index_shape=(1,),
        dtype=constify(elem_type),
        field_type=FieldType.CUSTOM,
    )

    x, y, z = sp.symbols("x, y, z")

    # Can assign to non-const LHS
    mutable_target = f.absolute_access([0], [0])
    typed_asm = typify(freeze(Assignment(mutable_target, x + y)))
    assert not typed_asm.lhs.get_dtype().const

    # Cannot assign to const left-hand side
    with pytest.raises(TypificationError):
        _ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y)))

    np_struct = np.dtype([("size", np.uint32), ("data", np.float32)], align=True)
    const_struct_type = constify(create_type(np_struct))
    struct_field = Field.create_generic(
        "struct_field", 1, dtype=const_struct_type, field_type=FieldType.CUSTOM
    )

    with pytest.raises(TypificationError):
        _ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x)))

    # Const LHS is only OK in declarations
    q = ctx.get_symbol("q", Fp(32, const=True))

    declaration = PsDeclaration(PsExpression.make(q), PsExpression.make(q))
    declaration = typify(declaration)
    assert declaration.lhs.dtype == Fp(32)

    assignment = PsAssignment(PsExpression.make(q), PsExpression.make(q))
    with pytest.raises(TypificationError):
        typify(assignment)
def test_typify_structs():
    """Struct member accesses typify when used with `x`, except for the `size` member."""
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    record_layout = np.dtype([("size", np.uint32), ("data", np.float32)], align=True)
    f = Field.create_generic("f", 1, dtype=record_layout, field_type=FieldType.CUSTOM)
    x = sp.Symbol("x")

    # Good: reading the `data` member into `x` ...
    read_data = freeze(Assignment(x, f.absolute_access((0,), "data")))
    read_data = typify(read_data)

    # ... and writing `x` back into it
    write_data = freeze(Assignment(f.absolute_access((0,), "data"), x))
    write_data = typify(write_data)

    # Bad: reading the `size` member into `x` is expected to be rejected
    read_size = freeze(Assignment(x, f.absolute_access((0,), "size")))
    with pytest.raises(TypificationError):
        read_size = typify(read_size)
def test_default_typing():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y, z = sp.symbols("x, y, z")
expr = freeze(2 * x + 3 * y + z - 4)
expr = typify(expr)
def check(expr):
assert expr.dtype == ctx.default_dtype
match expr:
case PsConstantExpr(cs):
assert cs.value in (2, 3, -4)
assert cs.dtype == constify(ctx.default_dtype)
case PsSymbolExpr(symb):
assert symb.name in "xyz"
assert symb.dtype == ctx.default_dtype
case PsBinOp(op1, op2):
check(op1)
check(op2)
case _:
pytest.fail(f"Unexpected expression: {expr}")
check(expr)
def test_inline_arrays_1d():
    """Entries of a 1D inline array adopt the type of the enclosing expression."""
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x = sp.Symbol("x")
    y = TypedSymbol("y", Fp(16))
    idx = TypedSymbol("idx", Int(32))

    arr: PsArrayInitList = cast(PsArrayInitList, freeze(sp.Tuple(1, 2, 3, 4)))

    target = freeze(x)
    summand = freeze(y)
    element = PsSubscript(arr, (freeze(idx),))
    decl = PsDeclaration(target, summand + element)

    # The array elements should learn their type from the context, which gets it from `y`
    decl = typify(decl)

    half = Fp(16)
    assert decl.lhs.dtype == half
    assert decl.rhs.dtype == half
    assert arr.dtype == Arr(half, (4,))
    assert all(item.dtype == half for item in arr.items)
def test_inline_arrays_3d():
    """Entries of a 3D inline array adopt the type of the enclosing expression."""
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x = sp.Symbol("x")
    y = TypedSymbol("y", Fp(16))
    idx = [TypedSymbol(f"idx_{i}", Int(32)) for i in range(3)]

    arr: PsArrayInitList = freeze(
        sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10)))
    )

    target = freeze(x)
    summand = freeze(y)
    element = PsSubscript(arr, tuple(freeze(index) for index in idx))
    decl = PsDeclaration(target, summand + element)

    # The array elements should learn their type from the context, which gets it from `y`
    decl = typify(decl)

    half = Fp(16)
    assert decl.lhs.dtype == half
    assert decl.rhs.dtype == half
    assert arr.dtype == Arr(half, (2, 3, 2))
    assert arr.shape == (2, 3, 2)
    assert all(item.dtype == half for item in arr.items)
def test_array_subscript():
    """Subscripting a typed array yields the array's element type."""
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    # One-dimensional array, one index
    vector = sp.IndexedBase(TypedSymbol("arr", Arr(Fp(32), (16,))))
    entry = typify(freeze(vector[3]))
    assert entry.dtype == Fp(32)

    # Two-dimensional array, two indices
    matrix = sp.IndexedBase(TypedSymbol("arr2", Arr(Fp(32), (7, 31))))
    entry = typify(freeze(matrix[3, 5]))
    assert entry.dtype == Fp(32)
def test_invalid_subscript():
    """Subscripts on non-arrays, with wrong index counts, or on pointers are rejected."""
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    # a scalar cannot be subscripted
    scalar_base = sp.IndexedBase(TypedSymbol("non_arr", Int(64)))
    frozen = freeze(scalar_base[3])
    with pytest.raises(TypificationError):
        frozen = typify(frozen)

    # two indices into a three-dimensional array
    rank3_base = sp.IndexedBase(
        TypedSymbol("wrong_shape_arr", Arr(Fp(32), (7, 31, 5)))
    )
    frozen = freeze(rank3_base[3, 5])
    with pytest.raises(TypificationError):
        frozen = typify(frozen)

    # raw pointers are not arrays, cannot enter subscript
    pointer_base = sp.IndexedBase(TypedSymbol("ptr", Ptr(Int(16))))
    frozen = freeze(pointer_base[37])
    with pytest.raises(TypificationError):
        frozen = typify(frozen)
def test_mem_acc():
    """A `mem_acc` through a pointer has the pointee type; its offset keeps its own type."""
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    ptr = TypedSymbol("ptr", Ptr(Int(64)))
    idx = TypedSymbol("idx", Int(32))

    access = typify(freeze(mem_acc(ptr, idx)))

    assert isinstance(access, PsMemAcc)
    assert access.dtype == Int(64)
    assert access.offset.dtype == Int(32)
def test_invalid_mem_acc():
    """`mem_acc` on a scalar or on an array (instead of a pointer) is rejected."""
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    # base is a plain integer
    non_ptr = TypedSymbol("non_ptr", Int(64))
    idx = TypedSymbol("idx", Int(32))
    frozen = freeze(mem_acc(non_ptr, idx))
    with pytest.raises(TypificationError):
        _ = typify(frozen)

    # base is an array
    arr = TypedSymbol("arr", Arr(Int(64), (31,)))
    idx = TypedSymbol("idx", Int(32))
    frozen = freeze(mem_acc(arr, idx))
    with pytest.raises(TypificationError):
        _ = typify(frozen)
def test_lhs_inference():
    """An untyped left-hand side symbol receives the type inferred for the right-hand side."""
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    q = TypedSymbol("q", np.float32)
    w = TypedSymbol("w", np.float16)

    # float32 on the right -> `x` becomes float32
    typed = typify(freeze(Assignment(x, 3 - q)))
    assert ctx.get_symbol("x").dtype == Fp(32)
    assert typed.lhs.dtype == Fp(32)

    # float16 on the right -> `y` becomes float16
    typed = typify(freeze(Assignment(y, 3 - w)))
    assert ctx.get_symbol("y").dtype == Fp(16)
    assert typed.lhs.dtype == Fp(16)

    # same for a hand-built assignment node
    typed = PsAssignment(PsExpression.make(ctx.get_symbol("z")), freeze(3 - w))
    typed = typify(typed)
    assert ctx.get_symbol("z").dtype == Fp(16)
    assert typed.lhs.dtype == Fp(16)

    # a relation on the right -> `r` becomes boolean
    typed = PsDeclaration(
        PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q))
    )
    typed = typify(typed)
    assert ctx.get_symbol("r").dtype == Bool()
    assert typed.lhs.dtype == Bool()
    assert typed.rhs.dtype == Bool()
def test_array_declarations():
    """Array declarations take their element type from default, entries, or the LHS."""
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")

    # Array type fallback to default
    arr1 = sp.Symbol("arr1")
    decl = typify(freeze(Assignment(arr1, sp.Tuple(1, 2, 3, 4))))
    expected = Arr(Fp(32), (4,))
    assert ctx.get_symbol("arr1").dtype == expected
    assert decl.lhs.dtype == decl.rhs.dtype == expected

    # Array type determined by default-typed symbol
    arr2 = sp.Symbol("arr2")
    decl = typify(freeze(Assignment(arr2, sp.Tuple((x, y, -7), (3, -2, 51)))))
    expected = Arr(Fp(32), (2, 3))
    assert ctx.get_symbol("arr2").dtype == expected
    assert decl.lhs.dtype == decl.rhs.dtype == expected

    # Array type determined by pre-typed symbol
    q = TypedSymbol("q", Fp(16))
    arr3 = sp.Symbol("arr3")
    decl = typify(freeze(Assignment(arr3, sp.Tuple((q, 2), (-q, 0.123)))))
    expected = Arr(Fp(16), (2, 2))
    assert ctx.get_symbol("arr3").dtype == expected
    assert decl.lhs.dtype == decl.rhs.dtype == expected

    # Array type determined by LHS symbol
    arr4 = TypedSymbol("arr4", Arr(Int(16), 4))
    decl = typify(freeze(Assignment(arr4, sp.Tuple(11, 1, 4, 2))))
    assert decl.lhs.dtype == decl.rhs.dtype == Arr(Int(16), 4)
def test_erronous_typing():
    """Expressions mixing incompatible types must fail to typify."""
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    q = TypedSymbol("q", np.float32)
    w = TypedSymbol("w", np.float16)

    # Untyped `x`, `y` combined with the float32 symbol `q`
    mixed = freeze(2 * x + 3 * y + q - 4)
    with pytest.raises(TypificationError):
        typify(mixed)

    # Conflict between LHS and RHS symbols
    frozen = freeze(Assignment(q, 3 - w))
    with pytest.raises(TypificationError):
        typify(frozen)

    # Do not propagate types back from LHS symbols to RHS symbols
    frozen = freeze(Assignment(q, 3 - x))
    with pytest.raises(TypificationError):
        typify(frozen)

    # Augmented assignment of a float32 expression to `z`
    frozen = freeze(AddAugmentedAssignment(z, 3 - q))
    with pytest.raises(TypificationError):
        typify(frozen)
def test_invalid_indices():
    """Default-typed symbols are rejected as array indices under a float default dtype."""
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    typify = Typifier(ctx)

    arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64), (61,))))
    x, y, z = [PsExpression.make(ctx.get_symbol(name)) for name in "xyz"]

    # Using default-typed symbols as array indices is illegal when the default type is a float
    store = PsAssignment(PsSubscript(arr, (x + y,)), z)
    with pytest.raises(TypificationError):
        typify(store)

    load = PsAssignment(z, PsSubscript(arr, (x + y,)))
    with pytest.raises(TypificationError):
        typify(load)
def test_typify_integer_binops():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
ctx.get_symbol("x", ctx.index_dtype)
ctx.get_symbol("y", ctx.index_dtype)
x, y = sp.symbols("x, y")
expr = bit_shift_left(
bit_shift_right(bitwise_and(2, 2), bitwise_or(x, y)), bitwise_xor(2, 2)
)
expr = freeze(expr)
expr = typify(expr)
def check(expr):
match expr:
case PsConstantExpr(cs):
assert cs.value == 2
assert cs.dtype == constify(ctx.index_dtype)
case PsSymbolExpr(symb):
assert symb.name in "xyz"
assert symb.dtype == ctx.index_dtype
case PsBinOp(op1, op2):
check(op1)
check(op2)
case _:
pytest.fail(f"Unexpected expression: {expr}")
check(expr)
def test_typify_integer_binops_floating_arg():
    """Shifting an untyped (default-typed) symbol must fail to typify."""
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x = sp.Symbol("x")
    shifted = freeze(bit_shift_left(x, 2))

    with pytest.raises(TypificationError):
        shifted = typify(shifted)
def test_typify_integer_binops_in_floating_context():
    """Adding an integer shift result to a default-typed symbol must fail to typify."""
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    # `i` is an index-typed symbol, `x` stays untyped
    ctx.get_symbol("i", ctx.index_dtype)
    x, i = sp.symbols("x, i")

    mixed = freeze(x + bit_shift_left(i, 2))

    with pytest.raises(TypificationError):
        mixed = typify(mixed)
def test_typify_constant_clones():
    """Typifying an expression must leave a clone taken beforehand untyped."""
    ctx = KernelCreationContext(default_dtype=Fp(32))
    typify = Typifier(ctx)

    three = PsConstantExpr(PsConstant(3.0))
    x = PsSymbolExpr(ctx.get_symbol("x"))
    total = three + x

    # Clone first, then typify only the original
    untouched = total.clone()
    total = typify(total)

    assert untouched.operand1.dtype is None
    assert cast(PsConstantExpr, untouched.operand1).constant.dtype is None
def test_typify_bools_and_relations():
    """A mix of relations and logical operators typifies to a boolean."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    true = PsConstantExpr(PsConstant(True, Bool()))
    p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"]
    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]

    formula = PsAnd(PsEq(x, y), PsAnd(true, PsNot(PsOr(p, q))))
    formula = typify(formula)

    assert formula.dtype == Bool()
def test_bool_in_numerical_context():
    """Arithmetic on boolean operands must fail to typify."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    true = PsConstantExpr(PsConstant(True, Bool()))
    p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"]

    arithmetic = true + (p - q)

    with pytest.raises(TypificationError):
        typify(arithmetic)
@pytest.mark.parametrize("rel", [PsEq, PsNe, PsLt, PsGt, PsLe, PsGe])
def test_typify_conditionals(rel):
    """Every relation used as a branch condition typifies to a boolean."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]

    branch = PsConditional(rel(x, y), PsBlock([]))
    branch = typify(branch)

    assert branch.condition.dtype == Bool()
def test_invalid_conditions():
    """A numerical expression is rejected as a branch condition."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]
    # Registered with the context as in the original test, but not used below
    p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"]

    branch = PsConditional(x + y, PsBlock([]))

    with pytest.raises(TypificationError):
        typify(branch)
def test_typify_ternary():
    """Ternaries need a boolean condition and two branches of a common type."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]
    a, b = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "ab"]
    p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"]

    # float32 branches
    selection = typify(PsTernary(p, x, y))
    assert selection.dtype == Fp(32)

    # int32 branches behind a compound condition
    selection = typify(PsTernary(PsAnd(p, q), a, b + a))
    assert selection.dtype == Int(32)

    # branches of different types
    mismatched_branches = PsTernary(PsAnd(p, q), a, x)
    with pytest.raises(TypificationError):
        typify(mismatched_branches)

    # non-boolean condition
    numeric_condition = PsTernary(y, a, b)
    with pytest.raises(TypificationError):
        typify(numeric_condition)
def test_cfunction():
    """Calls to a parsed `CFunction` are typed according to its annotated signature."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]
    p, q = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "pq"]

    # Only the signature is parsed; the body is never meant to run
    def _threeway(x: np.float32, y: np.float32) -> np.int32:
        assert False

    threeway = CFunction.parse(_threeway)

    call = typify(PsCall(threeway, [x, y]))

    assert call.get_dtype() == Int(32)
    assert [call.args[i].get_dtype() for i in (0, 1)] == [Fp(32), Fp(32)]

    # An int32 argument in a float32 slot is rejected
    with pytest.raises(TypificationError):
        _ = typify(PsCall(threeway, (x, p)))
def test_typify_integer_vectors():
    """Arithmetic on integer vectors and on broadcast scalars yields an integer vector."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    a, b, c = [
        PsExpression.make(ctx.get_symbol(name, PsVectorType(Int(32), 4)))
        for name in "abc"
    ]
    d, e = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "de"]

    # purely vector-typed operands
    vec_result = typify(a + (b / c) - a * c)
    assert vec_result.get_dtype() == PsVectorType(Int(32), 4)

    # scalar expressions broadcast to four lanes
    bcast_result = typify(PsVecBroadcast(4, d - e) - PsVecBroadcast(4, e / d))
    assert bcast_result.get_dtype() == PsVectorType(Int(32), 4)
def test_typify_bool_vectors():
    """Logical operations and relations on vectors yield a boolean vector."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = [
        PsExpression.make(ctx.get_symbol(name, PsVectorType(Fp(32), 4)))
        for name in "xy"
    ]
    p, q = [
        PsExpression.make(ctx.get_symbol(name, PsVectorType(Bool(), 4)))
        for name in "pq"
    ]

    # boolean-vector operands
    mask = typify(PsAnd(PsOr(p, q), p))
    assert mask.get_dtype() == PsVectorType(Bool(), 4)

    # relations between float vectors
    mask = typify(PsAnd(PsLt(x, y), PsGe(y, x)))
    assert mask.get_dtype() == PsVectorType(Bool(), 4)
def test_inference_fails():
    """An untyped constant in a relation, array initializer, or cast cannot be typified."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    untyped = PsExpression.make(PsConstant(42))

    # relation between untyped constants
    with pytest.raises(TypificationError):
        typify(PsEq(untyped, untyped))

    # array initializer holding only an untyped constant
    with pytest.raises(TypificationError):
        typify(PsArrayInitList([untyped]))

    # cast applied to an untyped constant
    with pytest.raises(TypificationError):
        typify(PsCast(ctx.default_dtype, untyped))
import pytest
from pystencils import create_type
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
Typifier,
)
from pystencils.backend.memory import BufferBasePtr
from pystencils.backend.constants import PsConstant
from pystencils.backend.ast.expressions import (
PsExpression,
PsCast,
PsMemAcc,
PsArrayInitList,
PsSubscript,
PsBufferAcc,
PsSymbolExpr,
PsLe,
PsGe,
PsAnd,
)
from pystencils.backend.ast.structural import (
PsStatement,
PsAssignment,
PsDeclaration,
PsBlock,
PsConditional,
PsComment,
PsPragma,
PsLoop,
)
from pystencils.types.quick import Fp, Ptr, Bool
def test_cloning():
    """Clones of typified ASTs are distinct objects that are structurally equal."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y, z, m = [PsExpression.make(ctx.get_symbol(name)) for name in "xyzm"]
    q = PsExpression.make(ctx.get_symbol("q", create_type("bool")))
    a, b, c = [
        PsExpression.make(ctx.get_symbol(name, ctx.index_dtype)) for name in "abc"
    ]
    c1 = PsExpression.make(PsConstant(3.0))
    c2 = PsExpression.make(PsConstant(-1.0))
    one_f = PsExpression.make(PsConstant(1.0))
    one_i = PsExpression.make(PsConstant(1))

    def check(orig, clone):
        """Compare ``orig`` and ``clone`` node by node, descending into children."""
        assert orig is not clone
        assert type(orig) is type(clone)
        assert orig.structurally_equal(clone)

        if isinstance(orig, PsExpression):
            # Regression: Expression data types used to not be cloned
            assert orig.dtype == clone.dtype

        # strict=True also fails on differing numbers of children
        for orig_child, clone_child in zip(
            orig.children, clone.children, strict=True
        ):
            check(orig_child, clone_child)

    loop_body = PsBlock(
        [
            PsComment("Loop body"),
            PsAssignment(x, y),
            PsAssignment(x, y),
            PsPragma("#pragma clang loop vectorize(enable)"),
            PsStatement(
                PsMemAcc(PsCast(Ptr(Fp(32)), z), one_i)
                + PsCast(Fp(32), PsSubscript(m, (one_i + one_i + one_i, b + one_i)))
            ),
        ]
    )

    samples = [
        x,
        y,
        c1,
        x + y,
        x / y + c1,
        c1 + c2,
        PsStatement(x * y * z + c1),
        PsAssignment(y, x / c1),
        PsBlock([PsAssignment(x, c1 * y), PsAssignment(z, c2 + c1 * z)]),
        PsConditional(
            q, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")])
        ),
        PsDeclaration(m, PsArrayInitList([[x, y, one_f + x], [one_f, c2, z]])),
        PsPragma("omp parallel for"),
        PsLoop(a, b, c, one_i, loop_body),
    ]

    for sample in samples:
        typed = typify(sample)
        check(typed, typed.clone())
def test_buffer_acc():
    """`PsBufferAcc` tracks its buffer through cloning and base-pointer replacement."""
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)

    from pystencils import fields

    f, g = fields("f, g(3): [2D]")
    a, b = [ctx.get_symbol(n, ctx.index_dtype) for n in "ab"]

    def spatial_indices():
        """Return fresh symbol expressions for the counters `a` and `b`."""
        return [PsExpression.make(counter) for counter in (a, b)]

    # Access into the buffer of `f`
    f_buf = ctx.get_buffer(f)
    f_acc = PsBufferAcc(
        f_buf.base_pointer, spatial_indices() + [factory.parse_index(0)]
    )
    f_bptr_expr = PsSymbolExpr(f_buf.base_pointer)

    assert f_acc.buffer == f_buf
    assert f_acc.base_pointer.structurally_equal(f_bptr_expr)

    # The clone is a distinct object referring to the same buffer and indices
    f_acc_clone = f_acc.clone()
    assert f_acc_clone is not f_acc

    assert f_acc_clone.buffer == f_buf
    assert f_acc_clone.base_pointer.structurally_equal(
        PsSymbolExpr(f_buf.base_pointer)
    )
    assert len(f_acc_clone.index) == 3
    assert f_acc_clone.index[0].structurally_equal(PsSymbolExpr(ctx.get_symbol("a")))
    assert f_acc_clone.index[1].structurally_equal(PsSymbolExpr(ctx.get_symbol("b")))

    # Access into the buffer of `g`
    g_buf = ctx.get_buffer(g)
    g_acc = PsBufferAcc(
        g_buf.base_pointer, spatial_indices() + [factory.parse_index(2)]
    )

    assert g_acc.buffer == g_buf
    assert g_acc.base_pointer.structurally_equal(PsSymbolExpr(g_buf.base_pointer))

    # A second base pointer tagged with the same buffer may replace the first
    second_bptr = PsExpression.make(
        ctx.get_symbol("data_g_interior", g_buf.base_pointer.dtype)
    )
    second_bptr.symbol.add_property(BufferBasePtr(g_buf))
    g_acc.base_pointer = second_bptr

    assert g_acc.base_pointer == second_bptr
    assert g_acc.buffer == g_buf

    # cannot change base pointer to different buffer
    with pytest.raises(ValueError):
        g_acc.base_pointer = PsExpression.make(f_buf.base_pointer)
from pystencils.backend.ast.expressions import PsExpression
from pystencils.backend.memory import PsSymbol
from pystencils.backend.constants import PsConstant
from pystencils.types.quick import Fp, SInt, UInt, Bool
from pystencils.backend.emission import CAstPrinter
def test_arithmetic_precedence():
    """The C printer parenthesizes arithmetic only where precedence requires it."""
    (a, b, c, d, e, f) = [PsExpression.make(PsSymbol(x, Fp(64))) for x in "abcdef"]
    cprint = CAstPrinter()

    def prints_as(expr, expected):
        """Assert that ``expr`` is emitted as exactly ``expected``."""
        code = cprint(expr)
        assert code == expected

    # addition: the right-nested operand keeps its parentheses
    prints_as((a + b) + (c + d), "a + b + (c + d)")
    prints_as(((a + b) + c) + d, "a + b + c + d")
    prints_as(a + (b + (c + d)), "a + (b + (c + d))")

    # subtraction
    prints_as(a - (b - c) - d, "a - (b - c) - d")

    # mixed additive and multiplicative operators
    prints_as(a + b * (c + d * (e + f)), "a + b * (c + d * (e + f))")

    # unary minus
    prints_as((-a) + b + (-c) + -(e + f), "-a + b + -c + -(e + f)")

    # division and multiplication
    prints_as((a / b) + (c / (d + e) * f), "a / b + c / (d + e) * f")
def test_printing_integer_functions():
    """Shift, bitwise, integer-division and remainder nodes print with C precedence."""
    (i, j, k) = [PsExpression.make(PsSymbol(x, UInt(64))) for x in "ijk"]
    cprint = CAstPrinter()

    from pystencils.backend.ast.expressions import (
        PsLeftShift,
        PsRightShift,
        PsBitwiseAnd,
        PsBitwiseOr,
        PsBitwiseXor,
        PsIntDiv,
        PsRem,
    )

    # Build the tree step by step, in the same order the nested form evaluates
    xor_jk = PsBitwiseXor(j, k)
    shifted_or_div = PsBitwiseOr(PsLeftShift(i, PsRightShift(j, k)), PsIntDiv(i, k))
    xor_all = PsBitwiseXor(xor_jk, shifted_or_div)
    expr = PsBitwiseAnd(xor_all + PsRem(i, k), i)

    code = cprint(expr)
    assert code == "(j ^ k ^ (i << (j >> k) | i / k)) + i % k & i"
def test_logical_precedence():
    """The C printer parenthesizes logical operators only where precedence requires it."""
    from pystencils.backend.ast.expressions import PsNot, PsAnd, PsOr

    p, q, r = [PsExpression.make(PsSymbol(x, Bool())) for x in "pqr"]
    true = PsExpression.make(PsConstant(True, Bool()))
    false = PsExpression.make(PsConstant(False, Bool()))
    cprint = CAstPrinter()

    def prints_as(expr, expected):
        """Assert that ``expr`` is emitted as exactly ``expected``."""
        code = cprint(expr)
        assert code == expected

    prints_as(PsNot(PsAnd(p, PsOr(q, r))), "!(p && (q || r))")

    prints_as(PsAnd(PsAnd(p, q), PsAnd(q, r)), "p && q && (q && r)")

    prints_as(
        PsOr(PsAnd(true, p), PsOr(PsAnd(false, PsNot(q)), PsAnd(r, p))),
        "true && p || (false && !q || r && p)",
    )

    prints_as(
        PsAnd(PsOr(PsNot(p), PsNot(q)), PsNot(PsOr(true, false))),
        "(!p || !q) && !(true || false)",
    )
def test_relations_precedence():
    """Relations combined with logical operators print with minimal parentheses."""
    from pystencils.backend.ast.expressions import (
        PsNot,
        PsAnd,
        PsOr,
        PsEq,
        PsNe,
        PsLt,
        PsGt,
        PsLe,
        PsGe,
    )

    x, y, z = [PsExpression.make(PsSymbol(x, Fp(32))) for x in "xyz"]
    cprint = CAstPrinter()

    def prints_as(expr, expected):
        """Assert that ``expr`` is emitted as exactly ``expected``."""
        code = cprint(expr)
        assert code == expected

    # relations need no parentheses inside && / ||
    prints_as(PsAnd(PsEq(x, y), PsLe(y, z)), "x == y && y <= z")
    prints_as(PsOr(PsLt(x, y), PsLt(y, z)), "x < y || y < z")

    # negated relations are parenthesized
    prints_as(
        PsAnd(PsNot(PsGe(x, y)), PsNot(PsLe(y, z))), "!(x >= y) && !(y <= z)"
    )
    prints_as(PsOr(PsNe(x, y), PsNot(PsGt(y, z))), "x != y || !(y > z)")
def test_ternary():
    """Ternary expressions print with parentheses where nesting requires them."""
    from pystencils.backend.ast.expressions import PsTernary
    from pystencils.backend.ast.expressions import PsAnd, PsOr

    p, q = [PsExpression.make(PsSymbol(x, Bool())) for x in "pq"]
    x, y, z = [PsExpression.make(PsSymbol(x, Fp(32))) for x in "xyz"]
    cprint = CAstPrinter()

    def prints_as(expr, expected):
        """Assert that ``expr`` is emitted as exactly ``expected``."""
        code = cprint(expr)
        assert code == expected

    # plain and compound conditions
    prints_as(PsTernary(p, x, y), "p ? x : y")
    prints_as(PsTernary(PsAnd(p, q), x + y, z), "p && q ? x + y : z")

    # nested in the "then" branch: parenthesized
    prints_as(PsTernary(p, PsTernary(q, x, y), z), "p ? (q ? x : y) : z")

    # nested in the "else" branch: no parentheses
    prints_as(PsTernary(p, x, PsTernary(q, y, z)), "p ? x : q ? y : z")

    # nested in the condition: parenthesized
    prints_as(
        PsTernary(PsTernary(p, q, PsOr(p, q)), x, y), "(p ? q : p || q) ? x : y"
    )
def test_arrays():
    """Array declarations of rank 1 to 3 print as C arrays with nested initializers."""
    import sympy as sp
    from pystencils import Assignment
    from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory

    ctx = KernelCreationContext(default_dtype=SInt(32))
    factory = AstFactory(ctx)
    cprint = CAstPrinter()

    def declare_and_print(name, entries):
        """Parse ``name = entries`` through the factory and return the emitted code."""
        return cprint(factory.parse_sympy(Assignment(sp.Symbol(name), entries)))

    code = declare_and_print("a1d", sp.Tuple(1, 2, 3, 4, 5))
    assert code == "int32_t a1d[5] = { 1, 2, 3, 4, 5 };"

    code = declare_and_print("a2d", sp.Tuple((1, -1), (2, -2)))
    assert code == "int32_t a2d[2][2] = { { 1, -1 }, { 2, -2 } };"

    code = declare_and_print(
        "a3d", sp.Tuple(((1, -1), (2, -2)), ((3, -3), (4, -4)))
    )
    assert (
        code
        == "int32_t a3d[2][2][2] = { { { 1, -1 }, { 2, -2 } }, { { 3, -3 }, { 4, -4 } } };"
    )
import numpy as np
import pytest
from pystencils.types import PsTypeError
from pystencils.backend.constants import PsConstant
from pystencils.types.quick import Fp, Bool, UInt, SInt
from pystencils.backend.exceptions import PsInternalCompilerError
def test_constant_equality():
    """Constants compare and hash by value together with their type."""
    single_a = PsConstant(1.0, Fp(32))
    single_b = PsConstant(1.0, Fp(32))

    # same value, same type
    assert single_a == single_b
    assert hash(single_a) == hash(single_b)

    # same value, different type
    double = PsConstant(1.0, Fp(64))
    assert single_a != double
    assert hash(single_a) != hash(double)

    # reinterpreting yields a constant equal to one created with the new type
    reinterpreted = single_a.reinterpret_as(Fp(64))
    assert reinterpreted != single_a
    assert reinterpreted == double
def test_interpret():
    """`interpret_as` types an untyped constant and fails on an already typed one."""
    typed = PsConstant(3.4, Fp(32))
    untyped = PsConstant(3.4)

    assert untyped.interpret_as(Fp(32)) == typed

    with pytest.raises(PsInternalCompilerError):
        _ = typed.interpret_as(Fp(64))
def test_boolean_constants():
    """Boolean constants accept 0/1-like values and reject other numbers."""
    truthy_values = (1, 1.0, True, np.True_)
    falsy_values = (0, 0.0, False, np.False_)

    for flag, values in ((True, truthy_values), (False, falsy_values)):
        reference = PsConstant(flag, Bool())
        for val in values:
            assert PsConstant(val, Bool()) == reference

    # 1.1 is neither of the two boolean values
    with pytest.raises(PsTypeError):
        PsConstant(1.1, Bool())
def test_integer_bounds():
    """Integer constants are range-checked against their 8-bit target type."""
    # should not throw:
    accepted = (
        (UInt(8), (255, np.uint8(255), np.int16(255), np.int64(255))),
        (SInt(8), (-128, np.int16(-128), np.int64(-128))),
    )
    for dtype, values in accepted:
        for val in values:
            _ = PsConstant(val, dtype)

    # should throw:
    rejected = (
        (UInt(8), (256, np.int16(256), np.int64(256))),
        (UInt(8), (-42, np.int32(-42))),
        (SInt(8), (-129, np.int16(-129), np.int64(-129))),
    )
    for dtype, values in rejected:
        for val in values:
            with pytest.raises(PsTypeError):
                _ = PsConstant(val, dtype)
def test_floating_bounds():
    """Float constants are range-checked against half precision."""
    # These magnitudes must be accepted by all three widths
    for val in (5.1e4, -5.9e4):
        for width in (16, 32, 64):
            _ = PsConstant(val, Fp(width))

    # These magnitudes must be rejected by half precision
    for val in (8.1e5, -7.6e5):
        with pytest.raises(PsTypeError):
            _ = PsConstant(val, Fp(16))
import pytest
from pystencils import Target, Kernel
# from pystencils.backend.constraints import PsKernelParamsConstraint
from pystencils.backend.memory import PsSymbol, PsBuffer
from pystencils.backend.constants import PsConstant
from pystencils.backend.ast.expressions import PsBufferAcc, PsExpression
from pystencils.backend.ast.structural import PsAssignment, PsBlock, PsLoop
from pystencils.types.quick import SInt, Fp
from pystencils.jit import LegacyCpuJit
import numpy as np
@pytest.mark.xfail(reason="Fails until constraints are reimplemented")
def test_pairwise_addition():
    """Build ``v[i] = u[2*i] + u[2*i + 1]`` by hand, JIT-compile it and run it.

    Marked as an expected failure. The negative cases at the end depend on the
    kernel-parameter constraints that are commented out below.

    NOTE(review): `PsArrayBasePointer` is not among the imports visible next to
    this test, so the body presumably stops with a `NameError` before the kernel
    is built -- confirm and port to the current buffer API when the constraints
    are reimplemented.
    """
    idx_type = SInt(64)

    # `u` is read-only (const element type); `v` receives the result
    u = PsBuffer("u", Fp(64, const=True), (...,), (...,), index_dtype=idx_type)
    v = PsBuffer("v", Fp(64), (...,), (...,), index_dtype=idx_type)

    u_data = PsArrayBasePointer("u_data", u)
    v_data = PsArrayBasePointer("v_data", v)

    loop_ctr = PsExpression.make(PsSymbol("ctr", idx_type))

    zero = PsExpression.make(PsConstant(0, idx_type))
    one = PsExpression.make(PsConstant(1, idx_type))
    two = PsExpression.make(PsConstant(2, idx_type))

    # v[ctr] = u[2 * ctr] + u[2 * ctr + 1]
    update = PsAssignment(
        PsBufferAcc(v_data, loop_ctr),
        PsBufferAcc(u_data, two * loop_ctr) + PsBufferAcc(u_data, two * loop_ctr + one)
    )

    # Loop arguments: counter, `zero`, first shape entry of `v`, `one`, body
    loop = PsLoop(
        loop_ctr,
        zero,
        PsExpression.make(v.shape[0]),
        one,
        PsBlock([update])
    )

    func = Kernel(PsBlock([loop]), Target.CPU, "kernel", set())

    # sizes_constraint = PsKernelParamsConstraint(
    # u.shape[0].eq(2 * v.shape[0]),
    # "Array `u` must have twice the length of array `v`"
    # )

    # func.add_constraints(sizes_constraint)

    jit = LegacyCpuJit()
    kernel = jit.compile(func)

    # Positive case
    N = 21
    u_arr = np.arange(2 * N, dtype=np.float64)
    v_arr = np.zeros((N,), dtype=np.float64)

    assert u_arr.shape[0] == 2 * v_arr.shape[0]

    kernel(u=u_arr, v=v_arr)

    # Reference result computed in plain Python
    v_expected = np.zeros_like(v_arr)
    for i in range(N):
        v_expected[i] = u_arr[2 * i] + u_arr[2*i + 1]

    np.testing.assert_allclose(v_arr, v_expected)

    # Negative case - mismatched array sizes
    u_arr = np.zeros((N + 2,), dtype=np.float64)
    v_arr = np.zeros((N,), dtype=np.float64)

    with pytest.raises(ValueError):
        kernel(u=u_arr, v=v_arr)

    # Negative case - mismatched types
    u_arr = np.arange(2 * N, dtype=np.float64)
    v_arr = np.zeros((N,), dtype=np.float32)

    with pytest.raises(TypeError):
        kernel(u=u_arr, v=v_arr)
import sympy as sp
from pystencils import make_slice, Field, Assignment
from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace
from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations, LowerToC
from pystencils.backend.literals import PsLiteral
from pystencils.backend.emission import CAstPrinter
from pystencils.backend.ast.expressions import PsExpression, PsSubscript
from pystencils.backend.ast.structural import PsBlock, PsDeclaration
from pystencils.types.quick import Arr, Int
def test_literals():
    """Literals pass through canonicalization, hoisting and lowering and print verbatim."""
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)

    f = Field.create_generic("f", 3)
    x = sp.Symbol("x")

    cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64), (3,), const=True)))
    global_constant = PsExpression.make(PsLiteral("C", ctx.default_dtype))

    # Upper loop bounds are the three entries of the `CELLS` literal
    upper = [PsSubscript(cells, (factory.parse_index(axis),)) for axis in range(3)]
    loop_slice = make_slice[0:upper[0], 0:upper[1], 0:upper[2]]

    ispace = FullIterationSpace.create_from_slice(ctx, loop_slice)
    ctx.set_iteration_space(ispace)

    x_decl = PsDeclaration(factory.parse_sympy(x), global_constant)

    loop_body = PsBlock([x_decl, factory.parse_sympy(Assignment(f.center(), x))])

    loops = factory.loops_from_ispace(ispace, loop_body)
    ast = PsBlock([loops])

    # Apply the three transformations one after another
    for transformation in (
        CanonicalizeSymbols,
        HoistLoopInvariantDeclarations,
        LowerToC,
    ):
        apply = transformation(ctx)
        ast = apply(ast)

    # The declaration of `x` must now be the first of two top-level statements
    assert isinstance(ast, PsBlock)
    assert len(ast.statements) == 2
    assert ast.statements[0] == x_decl

    code = CAstPrinter()(ast)
    print(code)

    assert "const double x = C;" in code
    assert "CELLS[0LL]" in code
    assert "CELLS[1LL]" in code
    assert "CELLS[2LL]" in code
import pytest
from dataclasses import dataclass
from pystencils.backend.memory import PsSymbol, PsSymbolProperty, UniqueSymbolProperty
def test_properties():
    """Symbol properties behave like a set; unique properties allow one value per type."""

    @dataclass(frozen=True)
    class NumbersProperty(PsSymbolProperty):
        n: int
        x: float

    @dataclass(frozen=True)
    class StringProperty(PsSymbolProperty):
        s: str

    @dataclass(frozen=True)
    class MyUniqueProperty(UniqueSymbolProperty):
        val: int

    symbol = PsSymbol("s")

    # a fresh symbol carries no properties
    assert not symbol.properties

    numbers = NumbersProperty(42, 8.71)
    symbol.add_property(numbers)
    assert symbol.properties == {numbers}

    # no duplicates
    symbol.add_property(NumbersProperty(42, 8.71))
    assert symbol.properties == {numbers}

    text = StringProperty("pystencils")
    symbol.add_property(text)
    assert symbol.properties == {numbers, text}

    # lookup by property type
    assert symbol.get_properties(NumbersProperty) == {numbers}

    assert not symbol.get_properties(MyUniqueProperty)

    symbol.add_property(MyUniqueProperty(13))
    assert symbol.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)}

    # Adding the same one again does not raise
    symbol.add_property(MyUniqueProperty(13))
    assert symbol.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)}

    # a different value of the same unique property type is refused
    with pytest.raises(ValueError):
        symbol.add_property(MyUniqueProperty(14))

    symbol.remove_property(MyUniqueProperty(13))
    assert not symbol.get_properties(MyUniqueProperty)
from pystencils.backend.ast.analysis import OperationCounter
from pystencils.backend.ast.expressions import (
PsAdd,
PsConstant,
PsDiv,
PsExpression,
PsMul,
PsTernary,
)
from pystencils.backend.ast.structural import (
PsBlock,
PsDeclaration,
PsLoop,
)
from pystencils.backend.kernelcreation import KernelCreationContext, Typifier
from pystencils.types import PsBoolType
def test_count_operations():
    """`OperationCounter` scales the per-iteration operation counts by the trip count."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)
    counter = OperationCounter()

    x, y, z = [PsExpression.make(ctx.get_symbol(name)) for name in "xyz"]
    i = PsExpression.make(ctx.get_symbol("i", ctx.index_dtype))
    p = PsExpression.make(ctx.get_symbol("p", PsBoolType()))

    zero, two, five = [
        PsExpression.make(PsConstant(value, ctx.index_dtype)) for value in (0, 2, 5)
    ]

    # Per iteration: 1 add, 2 multiplications, 3 divisions and 1 ternary
    body = PsBlock(
        [
            PsDeclaration(x, PsAdd(y, z)),
            PsDeclaration(y, PsMul(x, PsMul(y, z))),
            PsDeclaration(z, PsDiv(PsDiv(PsDiv(x, y), z), PsTernary(p, x, y))),
        ]
    )

    # Loop arguments: counter `i`, then `zero`, `five`, `two`, and the body
    ast = typify(PsLoop(i, zero, five, two, body))

    op_count = counter(ast)

    # The assertions below scale each per-iteration count by three
    iterations = 3

    assert op_count.float_adds == iterations * 1
    assert op_count.float_muls == iterations * 2
    assert op_count.float_divs == iterations * 3
    assert op_count.int_adds == iterations * 1
    assert op_count.int_muls == 0
    assert op_count.int_divs == 0
    assert op_count.calls == 0
    assert op_count.branches == iterations * 1
    assert op_count.loops_with_dynamic_bounds == 0
import pytest
import sympy as sp
import numpy as np
from dataclasses import dataclass
from itertools import chain
from functools import partial
from typing import Callable
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.platforms import GenericVectorCpu, X86VectorArch, X86VectorCpu
from pystencils.backend.ast.structural import PsBlock
from pystencils.backend.transformations import (
LoopVectorizer,
SelectIntrinsics,
LowerToC,
)
from pystencils.backend.constants import PsConstant
from pystencils.codegen.driver import create_cpu_kernel_function
from pystencils.jit import LegacyCpuJit
from pystencils import Target, fields, Assignment, Field
from pystencils.field import create_numpy_array_with_layout
from pystencils.types import PsScalarType, PsIntegerType
from pystencils.types.quick import SInt, Fp
@dataclass
class VectorTestSetup:
    """One vectorization test configuration: target, platform, lane count and types."""

    # Target the configuration belongs to; its `name` starts the test ID
    target: Target
    # Callable that builds the vector platform for a kernel creation context
    platform_factory: Callable[[KernelCreationContext], GenericVectorCpu]
    # Number of vector lanes
    lanes: int
    # Scalar type used for numerical values
    numeric_dtype: PsScalarType
    # Integer type used for indices
    index_dtype: PsIntegerType

    @property
    def name(self) -> str:
        """Test ID of the form ``<target>/<numeric dtype><<lanes>>/<index dtype>``."""
        return "{}/{}<{}>/{}".format(
            self.target.name, self.numeric_dtype, self.lanes, self.index_dtype
        )
def get_setups(target: Target) -> list[VectorTestSetup]:
match target:
case Target.X86_SSE:
sse_platform = partial(X86VectorCpu, vector_arch=X86VectorArch.SSE)
return [
VectorTestSetup(target, sse_platform, 4, Fp(32), SInt(32)),
VectorTestSetup(target, sse_platform, 2, Fp(64), SInt(64)),
]
case Target.X86_AVX:
avx_platform = partial(X86VectorCpu, vector_arch=X86VectorArch.AVX)
return [
VectorTestSetup(target, avx_platform, 4, Fp(32), SInt(32)),
VectorTestSetup(target, avx_platform, 8, Fp(32), SInt(32)),
VectorTestSetup(target, avx_platform, 2, Fp(64), SInt(64)),
VectorTestSetup(target, avx_platform, 4, Fp(64), SInt(64)),
]
case Target.X86_AVX512:
avx512_platform = partial(X86VectorCpu, vector_arch=X86VectorArch.AVX512)
return [
VectorTestSetup(target, avx512_platform, 4, Fp(32), SInt(32)),
VectorTestSetup(target, avx512_platform, 8, Fp(32), SInt(32)),
VectorTestSetup(target, avx512_platform, 16, Fp(32), SInt(32)),
VectorTestSetup(target, avx512_platform, 2, Fp(64), SInt(64)),
VectorTestSetup(target, avx512_platform, 4, Fp(64), SInt(64)),
VectorTestSetup(target, avx512_platform, 8, Fp(64), SInt(64)),
]
case Target.X86_AVX512_FP16:
avx512_platform = partial(X86VectorCpu, vector_arch=X86VectorArch.AVX512_FP16)
return [
VectorTestSetup(target, avx512_platform, 8, Fp(16), SInt(32)),
VectorTestSetup(target, avx512_platform, 16, Fp(16), SInt(32)),
VectorTestSetup(target, avx512_platform, 32, Fp(16), SInt(32)),
]
case _:
return []
# All setups of every available vector CPU target, in target order
TEST_SETUPS: list[VectorTestSetup] = [
    setup
    for target in Target.available_vector_cpu_targets()
    for setup in get_setups(target)
]

# Test IDs, parallel to TEST_SETUPS
TEST_IDS = [setup.name for setup in TEST_SETUPS]
@pytest.fixture(params=TEST_SETUPS, ids=TEST_IDS)
def vectorization_setup(request) -> VectorTestSetup:
    """Provide one entry of ``TEST_SETUPS`` per parametrized test invocation."""
    current_setup: VectorTestSetup = request.param
    return current_setup
def create_vector_kernel(
    assignments: list[Assignment],
    field: Field,
    setup: VectorTestSetup,
    ghost_layers: int = 0,
):
    """Build and compile a vectorized CPU kernel from ``assignments``.

    The iteration space is derived from ``field`` with ``ghost_layers`` ghost
    layers, the loop whose counter is named ``ctr_0`` is vectorized with
    ``setup.lanes`` lanes, intrinsics are selected for the platform produced
    by ``setup.platform_factory``, and the result is lowered to C.

    Returns the compiled kernel as produced by the kernel function's
    ``compile()``.
    """
    ctx = KernelCreationContext(
        default_dtype=setup.numeric_dtype, index_dtype=setup.index_dtype
    )
    platform = setup.platform_factory(ctx)
    factory = AstFactory(ctx)

    ispace = FullIterationSpace.create_with_ghost_layers(ctx, ghost_layers, field)
    ctx.set_iteration_space(ispace)

    parsed_assignments = list(map(factory.parse_sympy, assignments))
    ast = factory.loops_from_ispace(ispace, PsBlock(parsed_assignments), field.layout)

    # Set inner strides to one to ensure packed memory access
    unit_stride = None
    for known_field in ctx.fields:
        unit_stride = PsConstant(1, ctx.index_dtype)
        ctx.get_buffer(known_field).strides[0] = unit_stride

    def is_ctr_0_loop(loop) -> bool:
        # Only the loop over the counter named "ctr_0" gets vectorized.
        return loop.counter.symbol.name == "ctr_0"

    ast = LoopVectorizer(ctx, setup.lanes).vectorize_select_loops(ast, is_ctr_0_loop)
    ast = SelectIntrinsics(ctx, platform)(ast)
    ast = LowerToC(ctx)(ast)

    kernel_function = create_cpu_kernel_function(
        ctx,
        platform,
        PsBlock([ast]),
        "vector_kernel",
        Target.CPU,
        LegacyCpuJit(),
    )
    return kernel_function.compile()
@pytest.mark.parametrize("ghost_layers", [0, 2])
def test_update_kernel(vectorization_setup: VectorTestSetup, ghost_layers: int):
    """Check a vectorized +, -, *, / kernel against a NumPy reference.

    The kernel reads two components of ``src`` and writes their sum,
    difference, product and quotient to the four components of ``dst``.
    The interior (everything except ``ghost_layers`` cells on each side of
    both spatial axes) must match the reference; the ghost layers must be
    left at their initial value of zero.
    """
    setup = vectorization_setup
    np_dtype = setup.numeric_dtype.numpy_dtype
    src, dst = fields(f"src(2), dst(4): {setup.numeric_dtype}[2D]", layout="fzyx")

    x = sp.symbols("x_:4")

    update = [
        Assignment(x[0], src[0, 0](0) + src[0, 0](1)),
        Assignment(x[1], src[0, 0](0) - src[0, 0](1)),
        Assignment(x[2], src[0, 0](0) * src[0, 0](1)),
        Assignment(x[3], src[0, 0](0) / src[0, 0](1)),
        Assignment(dst.center(0), x[0]),
        Assignment(dst.center(1), x[1]),
        Assignment(dst.center(2), x[2]),
        Assignment(dst.center(3), x[3]),
    ]

    kernel = create_vector_kernel(update, src, setup, ghost_layers)

    shape = (23, 17)

    rgen = np.random.default_rng(seed=1648)
    src_arr = create_numpy_array_with_layout(
        shape + (2,), layout=(2, 1, 0), dtype=np_dtype
    )
    # Generator.random() only accepts float32/float64 as dtype, so draw in
    # double precision and let the assignment cast to the test dtype. This
    # keeps the test usable for half-precision setups as well.
    src_arr[:] = rgen.random(src_arr.shape)

    dst_arr = create_numpy_array_with_layout(
        shape + (4,), layout=(2, 1, 0), dtype=np_dtype
    )
    dst_arr[:] = 0.0

    check_arr = np.zeros_like(dst_arr)
    check_arr[:, :, 0] = src_arr[:, :, 0] + src_arr[:, :, 1]
    check_arr[:, :, 1] = src_arr[:, :, 0] - src_arr[:, :, 1]
    check_arr[:, :, 2] = src_arr[:, :, 0] * src_arr[:, :, 1]
    check_arr[:, :, 3] = src_arr[:, :, 0] / src_arr[:, :, 1]

    kernel(src=src_arr, dst=dst_arr)

    resolution = np.finfo(np_dtype).resolution
    gls = ghost_layers

    # Use explicit end indices: `gls:-gls` collapses to the empty slice `0:0`
    # when there are no ghost layers, which would make the comparison vacuous.
    interior = (slice(gls, shape[0] - gls), slice(gls, shape[1] - gls))

    np.testing.assert_allclose(
        dst_arr[interior],
        check_arr[interior],
        rtol=resolution,
    )

    # Ghost layers on both ends of each spatial axis must be untouched.
    # The i-th layer from the end is index -(i + 1); plain -i would alias
    # index 0 for i == 0 and never look at the last row/column.
    for i in range(gls):
        np.testing.assert_equal(dst_arr[i, :, :], 0.0)
        np.testing.assert_equal(dst_arr[-(i + 1), :, :], 0.0)
        np.testing.assert_equal(dst_arr[:, i, :], 0.0)
        np.testing.assert_equal(dst_arr[:, -(i + 1), :], 0.0)
def test_trailing_iterations(vectorization_setup: VectorTestSetup):
    """Run a doubling kernel on arrays of 12 full vectors plus 0..lanes-1 extra items.

    Every element starts at 1.0 and must be 2.0 afterwards, for each number
    of extra items.
    """
    setup = vectorization_setup
    np_dtype = setup.numeric_dtype.numpy_dtype

    f = fields(f"f(1): {setup.numeric_dtype}[1D]", layout="fzyx")
    doubling = [Assignment(f(0), 2 * f(0))]
    kernel = create_vector_kernel(doubling, f, setup)

    full_vector_items = setup.lanes * 12
    for remainder in range(setup.lanes):
        f_arr = create_numpy_array_with_layout(
            (full_vector_items + remainder, 1), layout=(1, 0), dtype=np_dtype
        )
        f_arr.fill(1.0)

        kernel(f=f_arr)

        np.testing.assert_equal(f_arr, 2.0)
def test_only_trailing_iterations(vectorization_setup: VectorTestSetup):
    """Run a doubling kernel on arrays shorter than one vector (1..lanes-1 items).

    Every element starts at 1.0 and must be 2.0 afterwards.
    """
    setup = vectorization_setup
    np_dtype = setup.numeric_dtype.numpy_dtype

    f = fields(f"f(1): {setup.numeric_dtype}[1D]", layout="fzyx")
    doubling = [Assignment(f(0), 2 * f(0))]
    kernel = create_vector_kernel(doubling, f, setup)

    for item_count in range(1, setup.lanes):
        f_arr = create_numpy_array_with_layout(
            (item_count, 1), layout=(1, 0), dtype=np_dtype
        )
        f_arr.fill(1.0)

        kernel(f=f_arr)

        np.testing.assert_equal(f_arr, 2.0)
import sympy as sp
from itertools import product
from pystencils import make_slice, fields, Assignment
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsBlock, PsPragma, PsLoop
from pystencils.backend.transformations import InsertPragmasAtLoops, LoopPragma
def test_insert_pragmas():
    """Check that InsertPragmasAtLoops places each pragma in front of its loop.

    Three pragmas are requested with nesting depths 0, 1 and -1 on a 3D loop
    nest. The test expects the first as a sibling directly before the
    outermost loop, the second as first statement of that loop's body, and
    the third as first statement of the body of the second loop found in
    pre-order.
    """
    ctx = KernelCreationContext()
    factory = AstFactory(ctx)
    f, g = fields("f, g: [3D]")

    ispace = FullIterationSpace.create_from_slice(
        ctx, make_slice[:, :, :], archetype_field=f
    )
    ctx.set_iteration_space(ispace)

    # Full 3x3x3 neighbourhood of offsets.
    offsets = [-1, 0, 1]
    stencil = list(product(offsets, offsets, offsets))

    update = Assignment(f.center(0), sum(g.neighbors(stencil)))
    loops = factory.loops_from_ispace(ispace, PsBlock([factory.parse_sympy(update)]))

    pragma_depth_0 = LoopPragma("omp parallel for", 0)
    pragma_depth_1 = LoopPragma("some nonsense pragma", 1)
    pragma_depth_last = LoopPragma("omp simd", -1)
    pragmas = (pragma_depth_0, pragma_depth_1, pragma_depth_last)

    ast = InsertPragmasAtLoops(ctx, pragmas)(loops)

    def expect_pragma(node, requested):
        # The node must be a pragma carrying exactly the requested text.
        assert isinstance(node, PsPragma)
        assert node.text == requested.text

    # The transformation wraps the outermost loop in a block: pragma, then loop.
    assert isinstance(ast, PsBlock)
    expect_pragma(ast.statements[0], pragma_depth_0)
    assert ast.statements[1] == loops

    expect_pragma(loops.body.statements[0], pragma_depth_1)

    all_loops = list(dfs_preorder(ast, lambda node: isinstance(node, PsLoop)))
    second_loop = all_loops[1]
    assert isinstance(second_loop, PsLoop)
    expect_pragma(second_loop.body.statements[0], pragma_depth_last)