Skip to content
Snippets Groups Projects

Add sqrt; fix domains in function testsuite; update freeze of sp.Pow

Merged Frederik Hennig requested to merge fhennig/sqrt into v2.0-dev
15 files
+ 212
53
Compare changes
  • Side-by-side
  • Inline
Files
15
@@ -212,40 +212,49 @@ class FreezeExpressions:
@@ -212,40 +212,49 @@ class FreezeExpressions:
base = expr.args[0]
base = expr.args[0]
exponent = expr.args[1]
exponent = expr.args[1]
base_frozen = self.visit_expr(base)
expr_frozen = self.visit_expr(base)
reciprocal = False
expand_product = False
if isinstance(exponent, sp.Rational):
# Decompose rational exponent
if exponent.is_Integer:
num: int = exponent.numerator
if exponent == 0:
denom: int = exponent.denominator
return PsExpression.make(PsConstant(1))
if exponent.is_negative:
reciprocal = True
exponent = -exponent
if exponent <= sp.Integer(
5
): # TODO: is this a sensible limit? maybe make this configurable.
expand_product = True
if expand_product:
frozen_expr = reduce(
mul,
[base_frozen]
+ [base_frozen.clone() for _ in range(0, int(exponent) - 1)],
)
else:
exponent_frozen = self.visit_expr(exponent)
frozen_expr = PsMathFunction(MathFunctions.Pow)(
base_frozen, exponent_frozen
)
if reciprocal:
if denom <= 2 and abs(num) <= 8:
one = PsExpression.make(PsConstant(1))
# At most a square root, and at most eight factors
frozen_expr = one / frozen_expr
return frozen_expr
reciprocal = False
 
 
if num < 0:
 
reciprocal = True
 
num = -num
 
 
if denom == 2:
 
expr_frozen = PsMathFunction(MathFunctions.Sqrt)(expr_frozen)
 
denom = 1
 
 
assert denom == 1
 
 
# Pairwise multiplication for logarithmic runtime
 
factors = [expr_frozen] + [expr_frozen.clone() for _ in range(num - 1)]
 
while len(factors) > 1:
 
combined = [x * y for x, y in zip(factors[::2], factors[1::2])]
 
if len(factors) % 2 == 1:
 
combined.append(factors[-1])
 
factors = combined
 
 
expr_frozen = factors.pop()
 
 
if reciprocal:
 
one = PsExpression.make(PsConstant(1))
 
expr_frozen = one / expr_frozen
 
 
return expr_frozen
 
 
# If we got this far, use pow
 
exponent_frozen = self.visit_expr(exponent)
 
expr_frozen = PsMathFunction(MathFunctions.Pow)(expr_frozen, exponent_frozen)
 
 
return expr_frozen
def map_Integer(self, expr: sp.Integer) -> PsConstantExpr:
def map_Integer(self, expr: sp.Integer) -> PsConstantExpr:
value = int(expr)
value = int(expr)
Loading