diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index f63328d81781c578786c0019edb1549b83a53c83..861239fdce6560820157af29342be876285382af 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -453,6 +453,72 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
     return rec_sum
 
 
+def summands(expr):
+    return set(expr.args) if isinstance(expr, sp.Add) else {expr}
+
+
+def simplify_by_equality(expr, a, b, c):
+    """
+    Uses the equality a = b + c, where a and b must be symbols, to simplify expr 
+    by attempting to express additive combinations of two quantities by the third.
+
+    This works on expressions that are reducible to the form 
+    :math:`a * (...) + b * (...) + c * (...)`,
+    without any mixed terms of a, b and c.
+    """
+    if not isinstance(a, sp.Symbol) or not isinstance(b, sp.Symbol):
+        raise ValueError("a and b must be symbols.")
+
+    c = sp.sympify(c)
+
+    if not (isinstance(c, sp.Symbol) or is_constant(c)):
+        raise ValueError("c must be either a symbol or a constant!")
+
+    expr = sp.sympify(expr)
+
+    expr_expanded = sp.expand(expr)
+    a_coeff = expr_expanded.coeff(a, 1)
+    expr_expanded -= (a * a_coeff).expand()
+    b_coeff = expr_expanded.coeff(b, 1)
+    expr_expanded -= (b * b_coeff).expand()
+    if isinstance(c, sp.Symbol):
+        c_coeff = expr_expanded.coeff(c, 1)
+        rest = expr_expanded - (c * c_coeff).expand()
+    else:
+        c_coeff = expr_expanded / c
+        rest = 0
+
+    a_summands = summands(a_coeff)
+    b_summands = summands(b_coeff)
+    c_summands = summands(c_coeff)
+
+    # replace b + c by a
+    b_plus_c_coeffs = b_summands & c_summands
+    for coeff in b_plus_c_coeffs:
+        rest += a * coeff
+    b_summands -= b_plus_c_coeffs
+    c_summands -= b_plus_c_coeffs
+
+    # replace a - b by c
+    neg_b_summands = {-x for x in b_summands}
+    a_minus_b_coeffs = a_summands & neg_b_summands
+    for coeff in a_minus_b_coeffs:
+        rest += c * coeff
+    a_summands -= a_minus_b_coeffs
+    b_summands -= {-x for x in a_minus_b_coeffs}
+
+    # replace a - c by b
+    neg_c_summands = {-x for x in c_summands}
+    a_minus_c_coeffs = a_summands & neg_c_summands
+    for coeff in a_minus_c_coeffs:
+        rest += b * coeff
+    a_summands -= a_minus_c_coeffs
+    c_summands -= {-x for x in a_minus_c_coeffs}
+
+    # put it back together
+    return (rest + a * sum(a_summands) + b * sum(b_summands) + c * sum(c_summands)).expand()
+
+
 def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
                      only_type: Optional[str] = 'real') -> Dict[str, int]:
     """Counts the number of additions, multiplications and division.
diff --git a/pystencils_tests/test_sympyextensions.py b/pystencils_tests/test_sympyextensions.py
index 82e0ef40206a293b99ed568ab6b7c75f28fd43a7..38a138d2b0d4a52d56214a6f2c5c4f0a7dedfd9c 100644
--- a/pystencils_tests/test_sympyextensions.py
+++ b/pystencils_tests/test_sympyextensions.py
@@ -1,11 +1,13 @@
 import sympy
 import numpy as np
+import sympy as sp
 import pystencils
 
 from pystencils.sympyextensions import replace_second_order_products
 from pystencils.sympyextensions import remove_higher_order_terms
 from pystencils.sympyextensions import complete_the_squares_in_exp
 from pystencils.sympyextensions import extract_most_common_factor
+from pystencils.sympyextensions import simplify_by_equality
 from pystencils.sympyextensions import count_operations
 from pystencils.sympyextensions import common_denominator
 from pystencils.sympyextensions import get_symmetric_part
@@ -176,3 +178,26 @@ def test_get_symmetric_part():
     sym_part = get_symmetric_part(expr, sympy.symbols(f'y z'))
 
     assert sym_part == expected_result
+
+
+def test_simplify_by_equality():
+    x, y, z = sp.symbols('x, y, z')
+    p, q = sp.symbols('p, q')
+
+    #   Let x = y + z
+    expr = x * p - y * p + z * q
+    expr = simplify_by_equality(expr, x, y, z)
+    assert expr == z * p + z * q
+
+    expr = x * (p - 2 * q) + 2 * q * z
+    expr = simplify_by_equality(expr, x, y, z)
+    assert expr == x * p - 2 * q * y
+
+    expr = x * (y + z) - y * z
+    expr = simplify_by_equality(expr, x, y, z)
+    assert expr == x*y + z**2
+
+    #   Let x = y + 2
+    expr = x * p - 2 * p
+    expr = simplify_by_equality(expr, x, y, 2)
+    assert expr == y * p