Skip to content
Snippets Groups Projects

Fix: `recursive_collect` now fails silently

Merged Frederik Hennig requested to merge da15siwa/pystencils:fix_extract_constants into master
1 file
+ 13
3
Compare changes
  • Side-by-side
  • Inline
@@ -6,6 +6,7 @@ from functools import partial, reduce
@@ -6,6 +6,7 @@ from functools import partial, reduce
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
import sympy as sp
import sympy as sp
 
from sympy import PolynomialError
from sympy.functions import Abs
from sympy.functions import Abs
from sympy.core.numbers import Zero
from sympy.core.numbers import Zero
@@ -442,11 +443,14 @@ def extract_most_common_factor(term):
@@ -442,11 +443,14 @@ def extract_most_common_factor(term):
def recursive_collect(expr, symbols, order_by_occurences=False):
def recursive_collect(expr, symbols, order_by_occurences=False):
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
and so on.
and so on.
 
 
``expr`` must be rewritable as a polynomial in the given ``symbols``.
 
It it is not, ``recursive_collect`` will fail quietly, returning the original expression.
Args:
Args:
expr: A sympy expression
expr: A sympy expression.
symbols: A sequence of symbols
symbols: A sequence of symbols
order_by_occurences: If True, during recursive descent, always collect the symbol occuring
order_by_occurences: If True, during recursive descent, always collect the symbol occuring
most often in the expression.
most often in the expression.
@@ -457,7 +461,13 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
@@ -457,7 +461,13 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
if len(symbols) == 0:
if len(symbols) == 0:
return expr
return expr
symbol = symbols[0]
symbol = symbols[0]
collected_poly = sp.Poly(expr.collect(symbol), symbol)
collected = expr.collect(symbol)
 
 
try:
 
collected_poly = sp.Poly(collected, symbol)
 
except PolynomialError:
 
return expr
 
coeffs = collected_poly.all_coeffs()[::-1]
coeffs = collected_poly.all_coeffs()[::-1]
rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
return rec_sum
return rec_sum
Loading