Skip to content
Snippets Groups Projects

Added simplify_by_equality

Merged Frederik Hennig requested to merge da15siwa/pystencils:simplify_equality into master
Files
2
@@ -453,6 +453,67 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
@@ -453,6 +453,67 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
return rec_sum
return rec_sum
 
def summands(expr):
 
return set(expr.args) if isinstance(expr, sp.Add) else {expr}
 
 
 
def simplify_by_equality(expr, a, b, c):
 
"""
 
Uses the equality a = b + c, where a and b must be symbols, to simplify expr
 
by attempting to express additive combinations of two quantities by the third.
 
 
This works on expressions that are reducible to the form
 
:math:`a * (...) + b * (...) + c * (...)`,
 
without any mixed terms of a, b and c.
 
"""
 
if not isinstance(a, sp.Symbol) or not isinstance(b, sp.Symbol):
 
raise ValueError("a and b must be symbols.")
 
 
expr_expanded = expr.expand()
 
a_coeff = expr_expanded.coeff(a, 1)
 
expr_expanded -= (a * a_coeff).expand()
 
b_coeff = expr_expanded.coeff(b, 1)
 
expr_expanded -= (b * b_coeff).expand()
 
if isinstance(c, sp.Symbol):
 
c_coeff = expr_expanded.coeff(c, 1)
 
rest = expr_expanded - (c * c_coeff).expand()
 
elif is_constant(c):
 
c_coeff = expr_expanded / c
 
rest = 0
 
else:
 
raise ValueError("c must be either a symbol or a constant!")
 
 
a_summands = summands(a_coeff)
 
b_summands = summands(b_coeff)
 
c_summands = summands(c_coeff)
 
 
# replace b + c by a
 
b_plus_c_coeffs = b_summands & c_summands
 
for coeff in b_plus_c_coeffs:
 
rest += a * coeff
 
b_summands -= b_plus_c_coeffs
 
c_summands -= b_plus_c_coeffs
 
 
# replace a - b by c
 
neg_b_summands = {-x for x in b_summands}
 
a_minus_b_coeffs = a_summands & neg_b_summands
 
for coeff in a_minus_b_coeffs:
 
rest += c * coeff
 
a_summands -= a_minus_b_coeffs
 
b_summands -= {-x for x in a_minus_b_coeffs}
 
 
# replace a - c by b
 
neg_c_summands = {-x for x in c_summands}
 
a_minus_c_coeffs = a_summands & neg_c_summands
 
for coeff in a_minus_c_coeffs:
 
rest += b * coeff
 
a_summands -= a_minus_c_coeffs
 
c_summands -= {-x for x in a_minus_c_coeffs}
 
 
# put it back together
 
return (rest + a * sum(a_summands) + b * sum(b_summands) + c * sum(c_summands)).expand()
 
 
def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
only_type: Optional[str] = 'real') -> Dict[str, int]:
only_type: Optional[str] = 'real') -> Dict[str, int]:
"""Counts the number of additions, multiplications and division.
"""Counts the number of additions, multiplications and division.
Loading