Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
pystencils
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pycodegen
pystencils
Commits
8c53c16a
Commit
8c53c16a
authored
3 years ago
by
Frederik Hennig
Committed by
Markus Holzer
3 years ago
Browse files
Options
Downloads
Patches
Plain Diff
Added simplify_by_equality
parent
825be1df
Branches
Branches containing commit
No related tags found
1 merge request
!286
Added simplify_by_equality
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
pystencils/sympyextensions.py
+66
-0
66 additions, 0 deletions
pystencils/sympyextensions.py
pystencils_tests/test_sympyextensions.py
+25
-0
25 additions, 0 deletions
pystencils_tests/test_sympyextensions.py
with
91 additions
and
0 deletions
pystencils/sympyextensions.py
+
66
−
0
View file @
8c53c16a
...
...
@@ -453,6 +453,72 @@ def recursive_collect(expr, symbols, order_by_occurences=False):
return
rec_sum
def
summands
(
expr
):
return
set
(
expr
.
args
)
if
isinstance
(
expr
,
sp
.
Add
)
else
{
expr
}
def
simplify_by_equality
(
expr
,
a
,
b
,
c
):
"""
Uses the equality a = b + c, where a and b must be symbols, to simplify expr
by attempting to express additive combinations of two quantities by the third.
This works on expressions that are reducible to the form
:math:`a * (...) + b * (...) + c * (...)`,
without any mixed terms of a, b and c.
"""
if
not
isinstance
(
a
,
sp
.
Symbol
)
or
not
isinstance
(
b
,
sp
.
Symbol
):
raise
ValueError
(
"
a and b must be symbols.
"
)
c
=
sp
.
sympify
(
c
)
if
not
(
isinstance
(
c
,
sp
.
Symbol
)
or
is_constant
(
c
)):
raise
ValueError
(
"
c must be either a symbol or a constant!
"
)
expr
=
sp
.
sympify
(
expr
)
expr_expanded
=
sp
.
expand
(
expr
)
a_coeff
=
expr_expanded
.
coeff
(
a
,
1
)
expr_expanded
-=
(
a
*
a_coeff
).
expand
()
b_coeff
=
expr_expanded
.
coeff
(
b
,
1
)
expr_expanded
-=
(
b
*
b_coeff
).
expand
()
if
isinstance
(
c
,
sp
.
Symbol
):
c_coeff
=
expr_expanded
.
coeff
(
c
,
1
)
rest
=
expr_expanded
-
(
c
*
c_coeff
).
expand
()
else
:
c_coeff
=
expr_expanded
/
c
rest
=
0
a_summands
=
summands
(
a_coeff
)
b_summands
=
summands
(
b_coeff
)
c_summands
=
summands
(
c_coeff
)
# replace b + c by a
b_plus_c_coeffs
=
b_summands
&
c_summands
for
coeff
in
b_plus_c_coeffs
:
rest
+=
a
*
coeff
b_summands
-=
b_plus_c_coeffs
c_summands
-=
b_plus_c_coeffs
# replace a - b by c
neg_b_summands
=
{
-
x
for
x
in
b_summands
}
a_minus_b_coeffs
=
a_summands
&
neg_b_summands
for
coeff
in
a_minus_b_coeffs
:
rest
+=
c
*
coeff
a_summands
-=
a_minus_b_coeffs
b_summands
-=
{
-
x
for
x
in
a_minus_b_coeffs
}
# replace a - c by b
neg_c_summands
=
{
-
x
for
x
in
c_summands
}
a_minus_c_coeffs
=
a_summands
&
neg_c_summands
for
coeff
in
a_minus_c_coeffs
:
rest
+=
b
*
coeff
a_summands
-=
a_minus_c_coeffs
c_summands
-=
{
-
x
for
x
in
a_minus_c_coeffs
}
# put it back together
return
(
rest
+
a
*
sum
(
a_summands
)
+
b
*
sum
(
b_summands
)
+
c
*
sum
(
c_summands
)).
expand
()
def
count_operations
(
term
:
Union
[
sp
.
Expr
,
List
[
sp
.
Expr
],
List
[
Assignment
]],
only_type
:
Optional
[
str
]
=
'
real
'
)
->
Dict
[
str
,
int
]:
"""
Counts the number of additions, multiplications and division.
...
...
This diff is collapsed.
Click to expand it.
pystencils_tests/test_sympyextensions.py
+
25
−
0
View file @
8c53c16a
import
sympy
import
numpy
as
np
import
sympy
as
sp
import
pystencils
from
pystencils.sympyextensions
import
replace_second_order_products
from
pystencils.sympyextensions
import
remove_higher_order_terms
from
pystencils.sympyextensions
import
complete_the_squares_in_exp
from
pystencils.sympyextensions
import
extract_most_common_factor
from
pystencils.sympyextensions
import
simplify_by_equality
from
pystencils.sympyextensions
import
count_operations
from
pystencils.sympyextensions
import
common_denominator
from
pystencils.sympyextensions
import
get_symmetric_part
...
...
@@ -176,3 +178,26 @@ def test_get_symmetric_part():
sym_part
=
get_symmetric_part
(
expr
,
sympy
.
symbols
(
f
'
y z
'
))
assert
sym_part
==
expected_result
def
test_simplify_by_equality
():
x
,
y
,
z
=
sp
.
symbols
(
'
x, y, z
'
)
p
,
q
=
sp
.
symbols
(
'
p, q
'
)
# Let x = y + z
expr
=
x
*
p
-
y
*
p
+
z
*
q
expr
=
simplify_by_equality
(
expr
,
x
,
y
,
z
)
assert
expr
==
z
*
p
+
z
*
q
expr
=
x
*
(
p
-
2
*
q
)
+
2
*
q
*
z
expr
=
simplify_by_equality
(
expr
,
x
,
y
,
z
)
assert
expr
==
x
*
p
-
2
*
q
*
y
expr
=
x
*
(
y
+
z
)
-
y
*
z
expr
=
simplify_by_equality
(
expr
,
x
,
y
,
z
)
assert
expr
==
x
*
y
+
z
**
2
# Let x = y + 2
expr
=
x
*
p
-
2
*
p
expr
=
simplify_by_equality
(
expr
,
x
,
y
,
2
)
assert
expr
==
y
*
p
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment