Skip to content
Snippets Groups Projects
Commit 040040c7 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

Merge remote-tracking branch 'origin/master' into nontemporal

parents b4534227 67019520
No related branches found
No related tags found
1 merge request!300Fix nontemporal stores on non-x86 for fields with variable size
Pipeline #43049 failed
......@@ -6,6 +6,7 @@ from functools import partial, reduce
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
import sympy as sp
from sympy import PolynomialError
from sympy.functions import Abs
from sympy.core.numbers import Zero
......@@ -445,8 +446,11 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
and so on.
``expr`` must be rewritable as a polynomial in the given ``symbols``.
It it is not, ``recursive_collect`` will fail quietly, returning the original expression.
Args:
expr: A sympy expression
expr: A sympy expression.
symbols: A sequence of symbols
order_by_occurences: If True, during recursive descent, always collect the symbol occuring
most often in the expression.
......@@ -457,7 +461,13 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
if len(symbols) == 0:
return expr
symbol = symbols[0]
collected_poly = sp.Poly(expr.collect(symbol), symbol)
collected = expr.collect(symbol)
try:
collected_poly = sp.Poly(collected, symbol)
except PolynomialError:
return expr
coeffs = collected_poly.all_coeffs()[::-1]
rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
return rec_sum
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment