-
Markus Holzer authoredMarkus Holzer authored
test_simplifications.py 6.95 KiB
from sys import version_info as vs
import pytest
import pystencils.config
import sympy as sp
import pystencils as ps
from pystencils import Assignment, AssignmentCollection, fields
from pystencils.simp import subexpression_substitution_in_main_assignments
from pystencils.simp import add_subexpressions_for_divisions
from pystencils.simp import add_subexpressions_for_sums
from pystencils.simp import add_subexpressions_for_field_reads
from pystencils.simp.simplifications import add_subexpressions_for_constants
from pystencils.typing import BasicType, TypedSymbol
a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
s0, s1, s2, s3 = sp.symbols("s_:4")
f = sp.symbols("f_:9")
def test_subexpression_substitution_in_main_assignments():
subexpressions = [
Assignment(s0, 2 * a + 2 * b),
Assignment(s1, 2 * a + 2 * b + 2 * c),
Assignment(s2, 2 * a + 2 * b + 2 * c + 2 * d),
Assignment(s3, 2 * a + 2 * b * c),
Assignment(x, s1 + s2 + s0 + s3)
]
main = [
Assignment(f[0], s1 + s2 + s0 + s3),
Assignment(f[1], s1 + s2 + s0 + s3),
Assignment(f[2], s1 + s2 + s0 + s3),
Assignment(f[3], s1 + s2 + s0 + s3),
Assignment(f[4], s1 + s2 + s0 + s3)
]
ac = AssignmentCollection(main, subexpressions)
ac = subexpression_substitution_in_main_assignments(ac)
for i in range(0, len(ac.main_assignments)):
assert ac.main_assignments[i].rhs == x
def test_add_subexpressions_for_divisions():
subexpressions = [
Assignment(s0, 2 / a + 2 / b),
Assignment(s1, 2 / a + 2 / b + 2 / c),
Assignment(s2, 2 / a + 2 / b + 2 / c + 2 / d),
Assignment(s3, 2 / a + 2 / b / c),
Assignment(x, s1 + s2 + s0 + s3)
]
main = [
Assignment(f[0], s1 + s2 + s0 + s3)
]
ac = AssignmentCollection(main, subexpressions)
divs_before_optimisation = ac.operation_count["divs"]
ac = add_subexpressions_for_divisions(ac)
divs_after_optimisation = ac.operation_count["divs"]
assert divs_before_optimisation - divs_after_optimisation == 8
rhs = []
for i in range(len(ac.subexpressions)):
rhs.append(ac.subexpressions[i].rhs)
assert 1/a in rhs
assert 1/b in rhs
assert 1/c in rhs
assert 1/d in rhs
def test_add_subexpressions_for_constants():
half = sp.Rational(1,2)
sqrt_2 = sp.sqrt(2)
main = [
Assignment(f[0], half * a + half * b + half * c),
Assignment(f[1], - half * a - half * b),
Assignment(f[2], a * sqrt_2 - b * sqrt_2),
Assignment(f[3], a**2 + b**2)
]
ac = AssignmentCollection(main)
ac = add_subexpressions_for_constants(ac)
assert len(ac.subexpressions) == 2
half_subexp = None
sqrt_subexp = None
for asm in ac.subexpressions:
if asm.rhs == half:
half_subexp = asm.lhs
elif asm.rhs == sqrt_2:
sqrt_subexp = asm.lhs
else:
pytest.fail(f"An unexpected subexpression was encountered: {asm}")
assert half_subexp is not None
assert sqrt_subexp is not None
for asm in ac.main_assignments[:3]:
assert isinstance(asm.rhs, sp.Mul)
assert any(arg == half_subexp for arg in ac.main_assignments[0].rhs.args)
assert any(arg == half_subexp for arg in ac.main_assignments[1].rhs.args)
assert any(arg == sqrt_subexp for arg in ac.main_assignments[2].rhs.args)
# Do not replace exponents!
assert ac.main_assignments[3].rhs == a**2 + b**2
def test_add_subexpressions_for_sums():
subexpressions = [
Assignment(s0, a + b + c + d),
Assignment(s1, 3 * a * sp.sqrt(x) + 4 * b + c),
Assignment(s2, 3 * a * sp.sqrt(x) + 4 * b + c),
Assignment(s3, 3 * a * sp.sqrt(x) + 4 * b + c)
]
main = [
Assignment(f[0], s1 + s2 + s0 + s3)
]
ac = AssignmentCollection(main, subexpressions)
ops_before_optimisation = ac.operation_count
ac = add_subexpressions_for_sums(ac)
ops_after_optimisation = ac.operation_count
assert ops_after_optimisation["adds"] == ops_before_optimisation["adds"]
assert ops_after_optimisation["muls"] < ops_before_optimisation["muls"]
assert ops_after_optimisation["sqrts"] < ops_before_optimisation["sqrts"]
rhs = []
for i in range(len(ac.subexpressions)):
rhs.append(ac.subexpressions[i].rhs)
assert a + b + c + d in rhs
assert 3 * a * sp.sqrt(x) in rhs
def test_add_subexpressions_for_field_reads():
s, v = fields("s(5), v(5): double[2D]")
subexpressions = []
main = [Assignment(s[0, 0](0), 3 * v[0, 0](0)),
Assignment(s[0, 0](1), 10 * v[0, 0](1))]
ac = AssignmentCollection(main, subexpressions)
assert len(ac.subexpressions) == 0
ac2 = add_subexpressions_for_field_reads(ac)
assert len(ac2.subexpressions) == 2
ac3 = add_subexpressions_for_field_reads(ac, data_type="float32")
assert len(ac3.subexpressions) == 2
assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol)
assert ac3.subexpressions[0].lhs.dtype == BasicType("float32")
# added check for early out of add_subexpressions_for_field_reads is no fields appear on the rhs (See #92)
main = [Assignment(s[0, 0](0), 3.0),
Assignment(s[0, 0](1), 4.0)]
ac4 = AssignmentCollection(main, subexpressions)
assert len(ac4.subexpressions) == 0
ac5 = add_subexpressions_for_field_reads(ac4)
assert ac5 is not None
assert ac4 is ac5
@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
def test_sympy_optimizations(target, dtype):
if target == ps.Target.GPU:
pytest.importorskip("cupy")
src, dst = ps.fields(f'src, dst: {dtype}[2d]')
assignments = ps.AssignmentCollection({
src[0, 0]: 1.0 * (sp.exp(dst[0, 0]) - 1)
})
config = pystencils.config.CreateKernelConfig(target=target, default_number_float=dtype)
ast = ps.create_kernel(assignments, config=config)
ps.show_code(ast)
code = ps.get_code_str(ast)
if dtype == 'float32':
assert 'expf(' in code
elif dtype == 'float64':
assert 'exp(' in code
@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
@pytest.mark.parametrize('simplification', (True, False))
@pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
def test_evaluate_constant_terms(target, simplification):
if target == ps.Target.GPU:
pytest.importorskip("cupy")
src, dst = ps.fields('src, dst: float32[2d]')
# cos of a number will always be simplified
assignments = ps.AssignmentCollection({
src[0, 0]: -sp.cos(1) + dst[0, 0]
})
config = pystencils.config.CreateKernelConfig(target=target, default_assignment_simplifications=simplification)
ast = ps.create_kernel(assignments, config=config)
code = ps.get_code_str(ast)
assert 'cos(' not in code