Skip to content
Snippets Groups Projects

Add sqrt; fix domains in function testsuite; update freeze of sp.Pow

Merged Frederik Hennig requested to merge fhennig/sqrt into v2.0-dev
15 files
+ 212
53
Compare changes
  • Side-by-side
  • Inline
Files
15
@@ -212,40 +212,49 @@ class FreezeExpressions:
base = expr.args[0]
exponent = expr.args[1]
base_frozen = self.visit_expr(base)
reciprocal = False
expand_product = False
if exponent.is_Integer:
if exponent == 0:
return PsExpression.make(PsConstant(1))
if exponent.is_negative:
reciprocal = True
exponent = -exponent
if exponent <= sp.Integer(
5
): # TODO: is this a sensible limit? maybe make this configurable.
expand_product = True
if expand_product:
frozen_expr = reduce(
mul,
[base_frozen]
+ [base_frozen.clone() for _ in range(0, int(exponent) - 1)],
)
else:
exponent_frozen = self.visit_expr(exponent)
frozen_expr = PsMathFunction(MathFunctions.Pow)(
base_frozen, exponent_frozen
)
expr_frozen = self.visit_expr(base)
if isinstance(exponent, sp.Rational):
# Decompose rational exponent
num: int = exponent.numerator
denom: int = exponent.denominator
if reciprocal:
one = PsExpression.make(PsConstant(1))
frozen_expr = one / frozen_expr
if denom <= 2 and abs(num) <= 8:
# At most a square root, and at most eight factors
return frozen_expr
reciprocal = False
if num < 0:
reciprocal = True
num = -num
if denom == 2:
expr_frozen = PsMathFunction(MathFunctions.Sqrt)(expr_frozen)
denom = 1
assert denom == 1
# Pairwise multiplication for logarithmic runtime
factors = [expr_frozen] + [expr_frozen.clone() for _ in range(num - 1)]
while len(factors) > 1:
combined = [x * y for x, y in zip(factors[::2], factors[1::2])]
if len(factors) % 2 == 1:
combined.append(factors[-1])
factors = combined
expr_frozen = factors.pop()
if reciprocal:
one = PsExpression.make(PsConstant(1))
expr_frozen = one / expr_frozen
return expr_frozen
# If we got this far, use pow
exponent_frozen = self.visit_expr(exponent)
expr_frozen = PsMathFunction(MathFunctions.Pow)(expr_frozen, exponent_frozen)
return expr_frozen
def map_Integer(self, expr: sp.Integer) -> PsConstantExpr:
value = int(expr)
Loading