Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
import numpy as np
import pytest
from itertools import product
from pystencils import (
create_kernel,
Target,
Assignment,
Field,
)
from pystencils.sympyextensions.typed_sympy import CastFunc
# All targets reported as available, minus those that carry the SSE flag.
AVAIL_TARGETS_NO_SSE = [
    target for target in Target.available_targets() if Target._SSE not in target
]

# Parametrization over (target, from_type, to_type).
# Two groups of targets are combined, each with its own set of data types:
#   1. x86 targets without the AVX512 flag: int32 / float32 / float64 only
#   2. every remaining target (non-x86, or x86 with AVX512): additionally int64
target_and_dtype = pytest.mark.parametrize(
    "target, from_type, to_type",
    [
        *product(
            [
                target
                for target in AVAIL_TARGETS_NO_SSE
                if Target._X86 in target and Target._AVX512 not in target
            ],
            [np.int32, np.float32, np.float64],
            [np.int32, np.float32, np.float64],
        ),
        *product(
            [
                target
                for target in AVAIL_TARGETS_NO_SSE
                # exact complement of the filter used for the first group
                if not (Target._X86 in target and Target._AVX512 not in target)
            ],
            [np.int32, np.int64, np.float32, np.float64],
            [np.int32, np.int64, np.float32, np.float64],
        ),
    ],
)
@target_and_dtype
def test_type_cast(gen_config, xp, from_type, to_type):
    """Compile a kernel applying ``CastFunc`` and compare it to ``astype``.

    The kernel writes ``CastFunc(inp, to_type)`` into ``outp``. The result is
    compared against the array library's own conversion of the same input.
    For float-to-non-float casts, either the truncated or the rounded
    reference is accepted.
    """
    source_is_float = np.issubdtype(from_type, np.floating)
    dest_is_float = np.issubdtype(to_type, np.floating)

    if source_is_float:
        values = [-1.25, -0, 1.5, 3, -5, -312, 42, 6.625, -9]
    else:
        values = [-1, 0, 1, 3, -5, -312, 42, 6, -9]

    inp = xp.array(values, dtype=from_type)
    outp = xp.zeros_like(inp).astype(to_type)

    # Reference results computed by the array library itself
    truncated = inp.astype(to_type)
    rounded = xp.round(inp).astype(to_type)

    src_field = Field.create_from_numpy_array("inp", inp)
    dst_field = Field.create_from_numpy_array("outp", outp)

    cast_assignment = Assignment(
        dst_field.center(), CastFunc(src_field.center(), to_type)
    )
    kernel = create_kernel([cast_assignment], gen_config)
    compiled = kernel.compile()
    compiled(inp=inp, outp=outp)

    if not (source_is_float and not dest_is_float):
        xp.testing.assert_array_equal(outp, truncated)
        return

    # rounding mode depends on platform
    try:
        xp.testing.assert_array_equal(outp, truncated)
    except AssertionError:
        xp.testing.assert_array_equal(outp, rounded)
import sympy as sp
import pytest
from pystencils import Assignment, TypedSymbol, fields, FieldType
from pystencils import Assignment, TypedSymbol, fields, FieldType, make_slice
from pystencils.sympyextensions import CastFunc, mem_acc
from pystencils.sympyextensions.pointers import AddressOf
......@@ -18,19 +18,25 @@ from pystencils.backend.transformations import (
AstVectorizer,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsBlock, PsDeclaration, PsAssignment
from pystencils.backend.ast.structural import (
PsBlock,
PsDeclaration,
PsAssignment,
PsLoop,
)
from pystencils.backend.ast.expressions import (
PsSymbolExpr,
PsConstantExpr,
PsExpression,
PsCast,
PsMemAcc,
PsCall
PsCall,
PsSubscript,
)
from pystencils.backend.functions import CFunction
from pystencils.backend.ast.vector import PsVecBroadcast, PsVecMemAcc
from pystencils.backend.exceptions import VectorizationError
from pystencils.types import PsVectorType, deconstify, create_type
from pystencils.types import PsArrayType, PsVectorType, deconstify, create_type
def test_vectorize_expressions():
......@@ -56,7 +62,9 @@ def test_vectorize_expressions():
factory.parse_sympy(-x * y + 13 * z - 4 * (x / w) * (x + z)),
factory.parse_sympy(sp.sin(x + z) - sp.cos(w)),
factory.parse_sympy(y**2 - x**2),
typify(- factory.parse_sympy(x / (w**2))), # place the negation outside, since SymPy would remove it
typify(
-factory.parse_sympy(x / (w**2))
), # place the negation outside, since SymPy would remove it
factory.parse_sympy(13 + (1 / w) - sp.exp(x) * 24),
]:
vec_expr = vectorize.visit(expr, vc)
......@@ -239,7 +247,31 @@ def test_reject_symbol_assignments():
with pytest.raises(VectorizationError):
_ = vectorize.visit(asm, vc)
def test_vectorize_assignments():
    """Vectorize a block holding a declaration of ``x`` and a re-assignment to it.

    After vectorization, the second statement must still be a ``PsAssignment``
    and the symbol on its left-hand side must have a vector data type.
    """
    context = KernelCreationContext()
    ast_factory = AstFactory(context)

    x, y = sp.symbols("x, y")

    vectorizer = AstVectorizer(context)
    vec_ctx = VectorizationContext(
        context,
        4,
        VectorizationAxis(context.get_symbol("ctr", context.index_dtype)),
    )

    # x = 0 (declaration), followed by x = 3 + y (assignment)
    declaration = PsDeclaration(
        ast_factory.parse_sympy(x),
        ast_factory.parse_sympy(sp.sympify(0)),
    )
    assignment = PsAssignment(
        ast_factory.parse_sympy(x),
        ast_factory.parse_sympy(3 + y),
    )

    vectorized_block = vectorizer.visit(PsBlock([declaration, assignment]), vec_ctx)

    vectorized_assignment = vectorized_block.statements[1]
    assert isinstance(vectorized_assignment, PsAssignment)
    assert isinstance(vectorized_assignment.lhs.symbol.dtype, PsVectorType)
def test_vectorize_memory_assignments():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
......@@ -260,7 +292,7 @@ def test_vectorize_memory_assignments():
asm = typify(
PsAssignment(
factory.parse_sympy(mem_acc(ptr, 3 * ctr + 2)),
factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3))
factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3)),
)
)
......@@ -303,7 +335,7 @@ def test_invalid_memory_assignments():
asm = typify(
PsAssignment(
factory.parse_sympy(mem_acc(ptr, 3 * i + 2)),
factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3))
factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3)),
)
)
......@@ -376,7 +408,9 @@ def test_vectorize_mem_acc():
assert vec_acc.vector_entries == 4
# Even more complex affine
idx = - factory.parse_index(ctr) / factory.parse_index(i) - factory.parse_index(ctr) * factory.parse_index(j)
idx = -factory.parse_index(ctr) / factory.parse_index(i) - factory.parse_index(
ctr
) * factory.parse_index(j)
acc = typify(PsMemAcc(factory.parse_sympy(ptr), idx))
assert isinstance(acc, PsMemAcc)
......@@ -386,11 +420,15 @@ def test_vectorize_mem_acc():
assert vec_acc.pointer.structurally_equal(acc.pointer)
assert vec_acc.offset is not acc.offset
assert vec_acc.offset.structurally_equal(acc.offset)
assert vec_acc.stride.structurally_equal(factory.parse_index(-1) / factory.parse_index(i) - factory.parse_index(j))
assert vec_acc.stride.structurally_equal(
factory.parse_index(-1) / factory.parse_index(i) - factory.parse_index(j)
)
assert vec_acc.vector_entries == 4
# Mixture of strides in affine and axis
vc = VectorizationContext(ctx, 4, VectorizationAxis(ctx.get_symbol("ctr"), step=factory.parse_index(3)))
vc = VectorizationContext(
ctx, 4, VectorizationAxis(ctx.get_symbol("ctr"), step=factory.parse_index(3))
)
acc = factory.parse_sympy(mem_acc(ptr, 3 * i + 5 * ctr))
assert isinstance(acc, PsMemAcc)
......@@ -421,7 +459,9 @@ def test_invalid_mem_acc():
ptr = TypedSymbol("ptr", create_type("float64 *"))
# Non-symbol pointer
acc = factory.parse_sympy(mem_acc(AddressOf(mem_acc(ptr, 10)), 3 * i + ctr * (3 + ctr)))
acc = factory.parse_sympy(
mem_acc(AddressOf(mem_acc(ptr, 10)), 3 * i + ctr * (3 + ctr))
)
with pytest.raises(VectorizationError):
_ = vectorize.visit(acc, vc)
......@@ -503,3 +543,107 @@ def test_invalid_buffer_acc():
with pytest.raises(VectorizationError):
_ = vectorize.visit(acc, vc)
def test_vectorize_subscript():
    """Vectorize an array subscript whose index does not use the axis counter.

    The vectorizer is expected to return a ``PsVecBroadcast`` whose operand
    is a ``PsSubscript``.
    """
    context = KernelCreationContext()
    ast_factory = AstFactory(context)
    vectorizer = AstVectorizer(context)

    counter = context.get_symbol("ctr", context.index_dtype)
    vec_ctx = VectorizationContext(context, 4, VectorizationAxis(counter))

    array_symbol = context.get_symbol("arr", PsArrayType(context.default_dtype, 42))
    index_symbol = context.get_symbol("i", context.index_dtype)

    # independent of vectorization axis
    subscript = PsSubscript(
        PsExpression.make(array_symbol),
        [PsExpression.make(index_symbol)],
    )

    result = vectorizer.visit(ast_factory._typify(subscript), vec_ctx)

    assert isinstance(result, PsVecBroadcast)
    assert isinstance(result.operand, PsSubscript)
def test_invalid_subscript():
    """An array subscript indexed by the axis counter must be rejected.

    Vectorizing it is expected to raise ``VectorizationError``.
    """
    context = KernelCreationContext()
    ast_factory = AstFactory(context)
    vectorizer = AstVectorizer(context)

    counter = context.get_symbol("ctr", context.index_dtype)
    vec_ctx = VectorizationContext(context, 4, VectorizationAxis(counter))

    array_symbol = context.get_symbol("arr", PsArrayType(context.default_dtype, 42))

    # depends on vectorization axis
    subscript = PsSubscript(
        PsExpression.make(array_symbol),
        [PsExpression.make(counter)],
    )

    with pytest.raises(VectorizationError):
        _ = vectorizer.visit(ast_factory._typify(subscript), vec_ctx)
def test_vectorize_nested_loop():
    """Vectorize along ``i`` a loop nest whose inner ``j`` loop has constant bounds.

    The inner loop must keep a step of one, while the declaration in its body
    must receive a vector data type.
    """
    context = KernelCreationContext()
    ast_factory = AstFactory(context)
    vectorizer = AstVectorizer(context)

    outer_counter = context.get_symbol("i", context.index_dtype)
    vec_ctx = VectorizationContext(context, 4, VectorizationAxis(outer_counter))

    # Loop body: a single declaration  x = 42
    body = PsBlock(
        [
            PsDeclaration(
                PsExpression.make(context.get_symbol("x", context.default_dtype)),
                PsExpression.make(PsConstant(42, context.default_dtype)),
            )
        ]
    )

    # inner loop does not depend on vectorization axis
    loop_nest = ast_factory.loop_nest(("i", "j"), make_slice[:8, :8], body)

    vectorized = vectorizer.visit(loop_nest, vec_ctx)

    def is_j_loop(node):
        return isinstance(node, PsLoop) and node.counter.symbol.name == "j"

    j_loop = next(dfs_preorder(vectorized, is_j_loop))
    inner_declaration = j_loop.body.statements[0]

    unit_step = PsExpression.make(PsConstant(1, context.index_dtype))
    assert j_loop.step.structurally_equal(unit_step)
    assert isinstance(inner_declaration.lhs.symbol.dtype, PsVectorType)
def test_invalid_nested_loop():
    """A nested loop whose bounds use the vectorized counter must be rejected.

    The inner ``j`` loop's upper limit is the ``i`` counter, which is also the
    vectorization axis; vectorizing is expected to raise ``VectorizationError``.
    """
    context = KernelCreationContext()
    ast_factory = AstFactory(context)
    vectorizer = AstVectorizer(context)

    outer_counter = context.get_symbol("i", context.index_dtype)
    vec_ctx = VectorizationContext(context, 4, VectorizationAxis(outer_counter))

    # Loop body: a single declaration  x = 42
    body = PsBlock(
        [
            PsDeclaration(
                PsExpression.make(context.get_symbol("x", context.default_dtype)),
                PsExpression.make(PsConstant(42, context.default_dtype)),
            )
        ]
    )

    # inner loop depends on vectorization axis
    loop_nest = ast_factory.loop_nest(
        ("i", "j"), make_slice[:8, :outer_counter], body
    )

    with pytest.raises(VectorizationError):
        _ = vectorizer.visit(loop_nest, vec_ctx)