Skip to content
Snippets Groups Projects
Commit aa6b64dc authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'simplify_equality' into 'master'

Added simplify_by_equality

See merge request pycodegen/pystencils!286
parents be198ac4 8c53c16a
No related merge requests found
...@@ -453,6 +453,72 @@ def recursive_collect(expr, symbols, order_by_occurences=False): ...@@ -453,6 +453,72 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
return rec_sum return rec_sum
def summands(expr):
return set(expr.args) if isinstance(expr, sp.Add) else {expr}
def simplify_by_equality(expr, a, b, c):
"""
Uses the equality a = b + c, where a and b must be symbols, to simplify expr
by attempting to express additive combinations of two quantities by the third.
This works on expressions that are reducible to the form
:math:`a * (...) + b * (...) + c * (...)`,
without any mixed terms of a, b and c.
"""
if not isinstance(a, sp.Symbol) or not isinstance(b, sp.Symbol):
raise ValueError("a and b must be symbols.")
c = sp.sympify(c)
if not (isinstance(c, sp.Symbol) or is_constant(c)):
raise ValueError("c must be either a symbol or a constant!")
expr = sp.sympify(expr)
expr_expanded = sp.expand(expr)
a_coeff = expr_expanded.coeff(a, 1)
expr_expanded -= (a * a_coeff).expand()
b_coeff = expr_expanded.coeff(b, 1)
expr_expanded -= (b * b_coeff).expand()
if isinstance(c, sp.Symbol):
c_coeff = expr_expanded.coeff(c, 1)
rest = expr_expanded - (c * c_coeff).expand()
else:
c_coeff = expr_expanded / c
rest = 0
a_summands = summands(a_coeff)
b_summands = summands(b_coeff)
c_summands = summands(c_coeff)
# replace b + c by a
b_plus_c_coeffs = b_summands & c_summands
for coeff in b_plus_c_coeffs:
rest += a * coeff
b_summands -= b_plus_c_coeffs
c_summands -= b_plus_c_coeffs
# replace a - b by c
neg_b_summands = {-x for x in b_summands}
a_minus_b_coeffs = a_summands & neg_b_summands
for coeff in a_minus_b_coeffs:
rest += c * coeff
a_summands -= a_minus_b_coeffs
b_summands -= {-x for x in a_minus_b_coeffs}
# replace a - c by b
neg_c_summands = {-x for x in c_summands}
a_minus_c_coeffs = a_summands & neg_c_summands
for coeff in a_minus_c_coeffs:
rest += b * coeff
a_summands -= a_minus_c_coeffs
c_summands -= {-x for x in a_minus_c_coeffs}
# put it back together
return (rest + a * sum(a_summands) + b * sum(b_summands) + c * sum(c_summands)).expand()
def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]], def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
only_type: Optional[str] = 'real') -> Dict[str, int]: only_type: Optional[str] = 'real') -> Dict[str, int]:
"""Counts the number of additions, multiplications and division. """Counts the number of additions, multiplications and division.
......
import sympy import sympy
import numpy as np import numpy as np
import sympy as sp
import pystencils import pystencils
from pystencils.sympyextensions import replace_second_order_products from pystencils.sympyextensions import replace_second_order_products
from pystencils.sympyextensions import remove_higher_order_terms from pystencils.sympyextensions import remove_higher_order_terms
from pystencils.sympyextensions import complete_the_squares_in_exp from pystencils.sympyextensions import complete_the_squares_in_exp
from pystencils.sympyextensions import extract_most_common_factor from pystencils.sympyextensions import extract_most_common_factor
from pystencils.sympyextensions import simplify_by_equality
from pystencils.sympyextensions import count_operations from pystencils.sympyextensions import count_operations
from pystencils.sympyextensions import common_denominator from pystencils.sympyextensions import common_denominator
from pystencils.sympyextensions import get_symmetric_part from pystencils.sympyextensions import get_symmetric_part
...@@ -176,3 +178,26 @@ def test_get_symmetric_part(): ...@@ -176,3 +178,26 @@ def test_get_symmetric_part():
sym_part = get_symmetric_part(expr, sympy.symbols(f'y z')) sym_part = get_symmetric_part(expr, sympy.symbols(f'y z'))
assert sym_part == expected_result assert sym_part == expected_result
def test_simplify_by_equality():
x, y, z = sp.symbols('x, y, z')
p, q = sp.symbols('p, q')
# Let x = y + z
expr = x * p - y * p + z * q
expr = simplify_by_equality(expr, x, y, z)
assert expr == z * p + z * q
expr = x * (p - 2 * q) + 2 * q * z
expr = simplify_by_equality(expr, x, y, z)
assert expr == x * p - 2 * q * y
expr = x * (y + z) - y * z
expr = simplify_by_equality(expr, x, y, z)
assert expr == x*y + z**2
# Let x = y + 2
expr = x * p - 2 * p
expr = simplify_by_equality(expr, x, y, 2)
assert expr == y * p
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment