Skip to content
Snippets Groups Projects
Commit 8bc8b39a authored by Markus Holzer's avatar Markus Holzer
Browse files

added test cases for sympyextensions

parent fb85c0b7
No related branches found
No related tags found
1 merge request!168Extend testsuite
......@@ -272,7 +272,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
positive: Optional[bool] = None,
replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
"""Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ).
"""Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ).
This makes the term longer - simplify usually is undoing these - however this
transformation can be done to find more common sub-expressions
......@@ -293,7 +293,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym
if expr.is_Mul:
distinct_search_symbols = set()
nr_of_search_terms = 0
other_factors = 1
other_factors = sp.Integer(1)
for t in expr.args:
if t in search_symbols:
nr_of_search_terms += 1
......@@ -481,7 +481,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
pass
elif t.func is sp.Mul:
if check_type(t):
result['muls'] += len(t.args) - 1
result['muls'] += len(t.args)
for a in t.args:
if a == 1 or a == -1:
result['muls'] -= 1
......@@ -515,7 +515,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
result['sqrts'] += 1
else:
warnings.warn("Cannot handle exponent", t.exp, " of sp.Pow node")
warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node")
else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate")
......@@ -526,7 +526,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(t, sp.Rel):
pass
else:
warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate")
warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
if visit_children:
for a in t.args:
......
import sympy
import pystencils
from pystencils.sympyextensions import replace_second_order_products
from pystencils.sympyextensions import remove_higher_order_terms
from pystencils.sympyextensions import complete_the_squares_in_exp
from pystencils.sympyextensions import extract_most_common_factor
from pystencils.sympyextensions import count_operations
from pystencils.sympyextensions import common_denominator
from pystencils.sympyextensions import get_symmetric_part
from pystencils import Assignment
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt,
insert_fast_divisions, insert_fast_sqrts)
def test_replace_second_order_products():
x, y = sympy.symbols('x y')
expr = 4 * x * y
expected_expr_positive = 2 * ((x + y) ** 2 - x ** 2 - y ** 2)
expected_expr_negative = 2 * (-(x - y) ** 2 + x ** 2 + y ** 2)
result = replace_second_order_products(expr, search_symbols=[x, y], positive=True)
assert result == expected_expr_positive
assert (result - expected_expr_positive).simplify() == 0
result = replace_second_order_products(expr, search_symbols=[x, y], positive=False)
assert result == expected_expr_negative
assert (result - expected_expr_negative).simplify() == 0
result = replace_second_order_products(expr, search_symbols=[x, y], positive=None)
assert result == expected_expr_positive
a = [Assignment(sympy.symbols('z'), x + y)]
replace_second_order_products(expr, search_symbols=[x, y], positive=True, replace_mixed=a)
assert len(a) == 2
def test_remove_higher_order_terms():
x, y = sympy.symbols('x y')
expr = sympy.Mul(x, y)
result = remove_higher_order_terms(expr, order=1, symbols=[x, y])
assert result == 0
result = remove_higher_order_terms(expr, order=2, symbols=[x, y])
assert result == expr
expr = sympy.Pow(x, 3)
result = remove_higher_order_terms(expr, order=2, symbols=[x, y])
assert result == 0
result = remove_higher_order_terms(expr, order=3, symbols=[x, y])
assert result == expr
def test_complete_the_squares_in_exp():
a, b, c, s, n = sympy.symbols('a b c s n')
expr = a * s ** 2 + b * s + c
result = complete_the_squares_in_exp(expr, symbols_to_complete=[s])
assert result == expr
expr = sympy.exp(a * s ** 2 + b * s + c)
expected_result = sympy.exp(a*s**2 + c - b**2 / (4*a))
result = complete_the_squares_in_exp(expr, symbols_to_complete=[s])
assert result == expected_result
def test_extract_most_common_factor():
x, y = sympy.symbols('x y')
expr = 1 / (x + y) + 3 / (x + y) + 3 / (x + y)
most_common_factor = extract_most_common_factor(expr)
assert most_common_factor[0] == 7
assert sympy.prod(most_common_factor) == expr
expr = 1 / x + 3 / (x + y) + 3 / y
most_common_factor = extract_most_common_factor(expr)
assert most_common_factor[0] == 3
assert sympy.prod(most_common_factor) == expr
expr = 1 / x
most_common_factor = extract_most_common_factor(expr)
assert most_common_factor[0] == 1
assert sympy.prod(most_common_factor) == expr
assert most_common_factor[1] == expr
def test_count_operations():
x, y, z = sympy.symbols('x y z')
expr = 1/x + y * sympy.sqrt(z)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
assert ops['muls'] == 1
assert ops['divs'] == 1
assert ops['sqrts'] == 1
expr = sympy.sqrt(x + y)
expr = insert_fast_sqrts(expr).atoms(fast_sqrt)
ops = count_operations(*expr, only_type=None)
assert ops['fast_sqrts'] == 1
expr = sympy.sqrt(x / y)
expr = insert_fast_divisions(expr).atoms(fast_division)
ops = count_operations(*expr, only_type=None)
assert ops['fast_div'] == 1
expr = pystencils.Assignment(sympy.Symbol('tmp'), 3 / sympy.sqrt(x + y))
expr = insert_fast_sqrts(expr).atoms(fast_inv_sqrt)
ops = count_operations(*expr, only_type=None)
assert ops['fast_inv_sqrts'] == 1
expr = sympy.Piecewise((1.0, x > 0), (0.0, True)) + y * z
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
expr = sympy.Pow(1/x + y * sympy.sqrt(z), 100)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
assert ops['muls'] == 100
assert ops['divs'] == 1
assert ops['sqrts'] == 1
def test_common_denominator():
x = sympy.symbols('x')
expr = sympy.Rational(1, 2) + x * sympy.Rational(2, 3)
cm = common_denominator(expr)
assert cm == 6
def test_get_symmetric_part():
x, y, z = sympy.symbols('x y z')
expr = x / 9 - y ** 2 / 6 + z ** 2 / 3 + z / 3
expected_result = x / 9 - y ** 2 / 6 + z ** 2 / 3
sym_part = get_symmetric_part(expr, sympy.symbols(f'y z'))
assert sym_part == expected_result
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment