Skip to content
Snippets Groups Projects
Commit fd4d1bc0 authored by Markus Holzer's avatar Markus Holzer
Browse files

Added test cases for pystencils simplifications

parent 8bc8b39a
No related branches found
No related tags found
1 merge request!168Extend testsuite
from .assignment_collection import AssignmentCollection
from .simplifications import (
add_subexpressions_for_divisions, add_subexpressions_for_field_reads,
apply_on_all_subexpressions, apply_to_all_assignments,
add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
subexpression_substitution_in_existing_subexpressions,
subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list)
from .simplificationstrategy import SimplificationStrategy
......@@ -10,4 +10,4 @@ __all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions',
'add_subexpressions_for_field_reads']
'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads']
......@@ -18,7 +18,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
elif isinstance(e1, Node):
symbols = e1.symbols_defined
else:
raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.")
raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.")
for lhs in symbols:
for c2, e2 in enumerate(assignments):
......@@ -112,14 +112,14 @@ def add_subexpressions_for_sums(ac):
addends = []
def contains_sum(term):
if term.func == sp.add.Add:
if term.func == sp.Add:
return True
if term.is_Atom:
return False
return any([contains_sum(a) for a in term.args])
def search_addends(term):
if term.func == sp.add.Add:
if term.func == sp.Add:
if all([not contains_sum(a) for a in term.args]):
addends.extend(term.args)
for a in term.args:
......
import sympy as sp
from pystencils.simp import subexpression_substitution_in_main_assignments
from pystencils.simp import add_subexpressions_for_divisions
from pystencils.simp import add_subexpressions_for_sums
from pystencils.simp import add_subexpressions_for_field_reads
from pystencils import Assignment, AssignmentCollection, fields
a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
s0, s1, s2, s3 = sp.symbols("s_:4")
f = sp.symbols("f_:9")
def test_subexpression_substitution_in_main_assignments():
subexpressions = [
Assignment(s0, 2 * a + 2 * b),
Assignment(s1, 2 * a + 2 * b + 2 * c),
Assignment(s2, 2 * a + 2 * b + 2 * c + 2 * d),
Assignment(s3, 2 * a + 2 * b * c),
Assignment(x, s1 + s2 + s0 + s3)
]
main = [
Assignment(f[0], s1 + s2 + s0 + s3),
Assignment(f[1], s1 + s2 + s0 + s3),
Assignment(f[2], s1 + s2 + s0 + s3),
Assignment(f[3], s1 + s2 + s0 + s3),
Assignment(f[4], s1 + s2 + s0 + s3)
]
ac = AssignmentCollection(main, subexpressions)
ac = subexpression_substitution_in_main_assignments(ac)
for i in range(0, len(ac.main_assignments)):
assert ac.main_assignments[i].rhs == x
def test_add_subexpressions_for_divisions():
subexpressions = [
Assignment(s0, 2 / a + 2 / b),
Assignment(s1, 2 / a + 2 / b + 2 / c),
Assignment(s2, 2 / a + 2 / b + 2 / c + 2 / d),
Assignment(s3, 2 / a + 2 / b / c),
Assignment(x, s1 + s2 + s0 + s3)
]
main = [
Assignment(f[0], s1 + s2 + s0 + s3)
]
ac = AssignmentCollection(main, subexpressions)
divs_before_optimisation = ac.operation_count["divs"]
ac = add_subexpressions_for_divisions(ac)
divs_after_optimisation = ac.operation_count["divs"]
assert divs_before_optimisation - divs_after_optimisation == 8
rhs = []
for i in range(len(ac.subexpressions)):
rhs.append(ac.subexpressions[i].rhs)
assert 1/a in rhs
assert 1/b in rhs
assert 1/c in rhs
assert 1/d in rhs
def test_add_subexpressions_for_sums():
subexpressions = [
Assignment(s0, a + b + c + d),
Assignment(s1, 3 * a * sp.sqrt(x) + 4 * b + c),
Assignment(s2, 3 * a * sp.sqrt(x) + 4 * b + c),
Assignment(s3, 3 * a * sp.sqrt(x) + 4 * b + c)
]
main = [
Assignment(f[0], s1 + s2 + s0 + s3)
]
ac = AssignmentCollection(main, subexpressions)
ops_before_optimisation = ac.operation_count
ac = add_subexpressions_for_sums(ac)
ops_after_optimisation = ac.operation_count
assert ops_after_optimisation["adds"] == ops_before_optimisation["adds"]
assert ops_after_optimisation["muls"] < ops_before_optimisation["muls"]
assert ops_after_optimisation["sqrts"] < ops_before_optimisation["sqrts"]
rhs = []
for i in range(len(ac.subexpressions)):
rhs.append(ac.subexpressions[i].rhs)
assert a + b + c + d in rhs
assert 3 * a * sp.sqrt(x) in rhs
def test_add_subexpressions_for_field_reads():
s, v = fields("s(5), v(5): double[2D]")
subexpressions = []
main = [
Assignment(s[0, 0](0), 3 * v[0, 0](0)),
Assignment(s[0, 0](1), 10 * v[0, 0](1))
]
ac = AssignmentCollection(main, subexpressions)
assert len(ac.subexpressions) == 0
ac = add_subexpressions_for_field_reads(ac)
assert len(ac.subexpressions) == 2
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment