Skip to content
Snippets Groups Projects
Commit 68b1efe2 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Reverted previous changes because they caused unexprected problems with exponentials.

parent eb850beb
No related branches found
No related tags found
1 merge request!207add_subexpressions_for_constants and new_filtered fix
......@@ -94,14 +94,17 @@ def add_subexpressions_for_constants(ac):
constants_to_subexp_dict = defaultdict(lambda: next(ac.subexpression_symbol_generator))
def visit(expr):
if is_constant(expr) and abs(expr) != 1:
expr = - constants_to_subexp_dict[- expr] if expr < 0 else constants_to_subexp_dict[expr]
args = list(expr.args)
if len(args) == 0:
return expr
elif len(expr.args) == 0:
return expr
else:
return expr.func(*(visit(a) for a in expr.args))
if isinstance(expr, sp.Add) or isinstance(expr, sp.Mul):
for i, arg in enumerate(args):
if is_constant(arg) and abs(arg) != 1:
if arg < 0:
args[i] = - constants_to_subexp_dict[- arg]
else:
args[i] = constants_to_subexp_dict[arg]
return expr.func(*(visit(a) for a in args))
main_assignments = [Assignment(a.lhs, visit(a.rhs)) for a in ac.main_assignments]
subexpressions = [Assignment(a.lhs, visit(a.rhs)) for a in ac.subexpressions]
......
......@@ -66,7 +66,8 @@ def test_add_subexpressions_for_constants():
main = [
Assignment(f[0], half * a + half * b + half * c),
Assignment(f[1], - half * a - half * b),
Assignment(f[2], a * sqrt_2 - b * sqrt_2)
Assignment(f[2], a * sqrt_2 - b * sqrt_2),
Assignment(f[3], a**2 + b**2)
]
ac = AssignmentCollection(main)
ac = add_subexpressions_for_constants(ac)
......@@ -87,13 +88,16 @@ def test_add_subexpressions_for_constants():
assert half_subexp is not None
assert sqrt_subexp is not None
for asm in ac.main_assignments:
for asm in ac.main_assignments[:3]:
assert isinstance(asm.rhs, sp.Mul)
assert any(arg == half_subexp for arg in ac.main_assignments[0].rhs.args)
assert any(arg == half_subexp for arg in ac.main_assignments[1].rhs.args)
assert any(arg == sqrt_subexp for arg in ac.main_assignments[2].rhs.args)
# Do not replace exponents!
assert ac.main_assignments[3].rhs == a**2 + b**2
def test_add_subexpressions_for_sums():
subexpressions = [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment