Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
pystencils
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pycodegen
pystencils
Commits
bed12f75
Commit
bed12f75
authored
8 years ago
by
Martin Bauer
Browse files
Options
Downloads
Patches
Plain Diff
pystencils: generalized equationcollection
parent
69ec4168
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
equationcollection/equationcollection.py
+71
-56
71 additions, 56 deletions
equationcollection/equationcollection.py
equationcollection/simplifications.py
+7
-12
7 additions, 12 deletions
equationcollection/simplifications.py
sympyextensions.py
+5
-1
5 additions, 1 deletion
sympyextensions.py
with
83 additions
and
69 deletions
equationcollection/equationcollection.py
+
71
−
56
View file @
bed12f75
import
sympy
as
sp
from
copy
import
copy
,
deepcopy
from
pystencils.sympyextensions
import
fastSubs
,
countNumberOfOperations
...
...
@@ -20,52 +21,50 @@ class EquationCollection:
# ----------------------------------------- Creation ---------------------------------------------------------------
def
__init__
(
self
,
equations
,
subExpressions
,
simplificationHints
=
{}
,
subexpressionSymbolNameGenerator
=
None
):
def
__init__
(
self
,
equations
,
subExpressions
,
simplificationHints
=
None
,
subexpressionSymbolNameGenerator
=
None
):
self
.
mainEquations
=
equations
self
.
subexpressions
=
subExpressions
if
simplificationHints
is
None
:
simplificationHints
=
{}
self
.
simplificationHints
=
simplificationHints
def
symbolGen
():
"""
Use this generator to create new unused symbols for subexpressions
"""
counter
=
0
while
True
:
counter
+=
1
newSymbol
=
sp
.
Symbol
(
"
xi_
"
+
str
(
counter
))
if
newSymbol
in
self
.
boundSymbols
:
continue
yield
newSymbol
class
SymbolGen
:
def
__init__
(
self
):
self
.
_ctr
=
0
def
__iter__
(
self
):
return
self
def
__next__
(
self
):
self
.
_ctr
+=
1
return
sp
.
Symbol
(
"
xi_
"
+
str
(
self
.
_ctr
))
if
subexpressionSymbolNameGenerator
is
None
:
self
.
subexpressionSymbolNameGenerator
=
s
ymbolGen
()
self
.
subexpressionSymbolNameGenerator
=
S
ymbolGen
()
else
:
self
.
subexpressionSymbolNameGenerator
=
subexpressionSymbolNameGenerator
def
newWithAdditionalSubexpressions
(
self
,
newEquations
,
additionalSubExpressions
):
"""
Returns a new equation collection, that has `newEquations` as mainEquations.
The `additionalSubExpressions` are appended to the existing subexpressions.
Simplifications hints are copied over.
"""
assert
len
(
self
.
mainEquations
)
==
len
(
newEquations
),
"
Number of update equations cannot be changed
"
res
=
EquationCollection
(
newEquations
,
self
.
subexpressions
+
additionalSubExpressions
,
self
.
simplificationHints
)
res
.
subexpressionSymbolNameGenerator
=
self
.
subexpressionSymbolNameGenerator
def
copy
(
self
,
mainEquations
=
None
,
subexpressions
=
None
):
res
=
deepcopy
(
self
)
if
mainEquations
is
not
None
:
res
.
mainEquations
=
mainEquations
if
subexpressions
is
not
None
:
res
.
subexpressions
=
subexpressions
return
res
def
new
WithSubstitutionsApplied
(
self
,
substitutionDict
,
addSubstitutionsAsSubexpresions
=
False
):
def
copy
WithSubstitutionsApplied
(
self
,
substitutionDict
,
addSubstitutionsAsSubexpres
s
ions
=
False
):
"""
Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`.
Substitutions are made in the subexpression terms and the main equations
"""
newSubexpressions
=
[
fastSubs
(
eq
,
substitutionDict
)
for
eq
in
self
.
subexpressions
]
newEquations
=
[
fastSubs
(
eq
,
substitutionDict
)
for
eq
in
self
.
mainEquations
]
if
addSubstitutionsAsSubexpresions
:
if
addSubstitutionsAsSubexpres
s
ions
:
newSubexpressions
=
[
sp
.
Eq
(
b
,
a
)
for
a
,
b
in
substitutionDict
.
items
()]
+
newSubexpressions
res
=
EquationCollection
(
newEquations
,
newSubexpressions
,
self
.
simplificationHints
)
res
.
subexpressionSymbolNameGenerator
=
self
.
subexpressionSymbolNameGenerator
return
res
return
self
.
copy
(
newEquations
,
newSubexpressions
)
def
addSimplificationHint
(
self
,
key
,
value
):
"""
...
...
@@ -178,41 +177,45 @@ class EquationCollection:
substitutionDict
[
otherSubexpressionEq
.
lhs
]
=
newLhs
else
:
processedOtherSubexpressionEquations
.
append
(
fastSubs
(
otherSubexpressionEq
,
substitutionDict
))
return
EquationCollection
(
self
.
mainEquations
+
other
.
mainEquations
,
self
.
subexpressions
+
processedOtherSubexpressionEquations
)
return
self
.
copy
(
self
.
mainEquations
+
other
.
mainEquations
,
self
.
subexpressions
+
processedOtherSubexpressionEquations
)
def
extract
(
self
,
symbolsToExtract
):
"""
Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
only the necessary subexpressions that are used in these equations
"""
symbolsToExtract
=
set
(
symbolsToExtract
)
newEquations
=
[]
def
getDependentSymbols
(
self
,
symbolSequence
):
"""
Returns a list of symbols that depend on the passed symbols.
"""
subexprMap
=
{
e
.
lhs
:
e
.
rhs
for
e
in
self
.
subexpressions
}
handledSymbols
=
set
()
queue
=
[]
queue
=
list
(
symbolSequence
)
def
addSymbolsFromExpr
(
expr
):
dependentSymbols
=
expr
.
atoms
(
sp
.
Symbol
)
for
ds
in
dependentSymbols
:
if
ds
not
in
handledSymbols
:
queue
.
append
(
ds
)
handledSymbols
.
add
(
ds
)
queue
.
append
(
ds
)
for
eq
in
self
.
allEquations
:
if
eq
.
lhs
in
symbolsToExtract
:
newEquations
.
append
(
eq
)
addSymbolsFromExpr
(
eq
.
rhs
)
handledSymbols
=
set
()
eqMap
=
{
e
.
lhs
:
e
.
rhs
for
e
in
self
.
allEquations
}
while
len
(
queue
)
>
0
:
e
=
queue
.
pop
(
0
)
if
e
not
in
subexprMap
:
if
e
in
handledSymbols
:
continue
else
:
addSymbolsFromExpr
(
subexprMap
[
e
])
if
e
in
eqMap
:
addSymbolsFromExpr
(
eqMap
[
e
])
handledSymbols
.
add
(
e
)
return
handledSymbols
def
extract
(
self
,
symbolsToExtract
):
"""
Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
only the necessary subexpressions that are used in these equations
"""
symbolsToExtract
=
set
(
symbolsToExtract
)
dependentSymbols
=
self
.
getDependentSymbols
(
symbolsToExtract
)
newEquations
=
[]
for
eq
in
self
.
allEquations
:
if
eq
.
lhs
in
symbolsToExtract
:
newEquations
.
append
(
eq
)
newSubExpr
=
[
eq
for
eq
in
self
.
subexpressions
if
eq
.
lhs
in
handled
Symbols
and
eq
.
lhs
not
in
symbolsToExtract
]
newSubExpr
=
[
eq
for
eq
in
self
.
subexpressions
if
eq
.
lhs
in
dependent
Symbols
and
eq
.
lhs
not
in
symbolsToExtract
]
return
EquationCollection
(
newEquations
,
newSubExpr
)
def
newWithoutUnusedSubexpressions
(
self
):
...
...
@@ -221,18 +224,30 @@ class EquationCollection:
allLhs
=
[
eq
.
lhs
for
eq
in
self
.
mainEquations
]
return
self
.
extract
(
allLhs
)
def
insertSubexpressions
(
self
):
def
insertSubexpressions
(
self
,
subexpressionSymbolsToKeep
=
set
()
):
"""
Returns a new equation collection by inserting all subexpressions into the main equations
"""
if
len
(
self
.
subexpressions
)
==
0
:
return
EquationCollection
(
self
.
mainEquations
,
self
.
subexpressions
,
self
.
simplificationHints
)
subsDict
=
{
self
.
subexpressions
[
0
].
lhs
:
self
.
subexpressions
[
0
].
rhs
}
return
self
.
copy
()
subexpressionSymbolsToKeep
=
set
(
subexpressionSymbolsToKeep
)
keptSubexpressions
=
[]
if
self
.
subexpressions
[
0
].
lhs
in
subexpressionSymbolsToKeep
:
subsDict
=
{}
keptSubexpressions
=
self
.
subexpressions
[
0
]
else
:
subsDict
=
{
self
.
subexpressions
[
0
].
lhs
:
self
.
subexpressions
[
0
].
rhs
}
subExpr
=
[
e
for
e
in
self
.
subexpressions
]
for
i
in
range
(
1
,
len
(
subExpr
)):
subExpr
[
i
]
=
fastSubs
(
subExpr
[
i
],
subsDict
)
subsDict
[
subExpr
[
i
].
lhs
]
=
subExpr
[
i
].
rhs
if
subExpr
[
i
].
lhs
in
subexpressionSymbolsToKeep
:
keptSubexpressions
.
append
(
subExpr
[
i
])
else
:
subsDict
[
subExpr
[
i
].
lhs
]
=
subExpr
[
i
].
rhs
newEq
=
[
fastSubs
(
eq
,
subsDict
)
for
eq
in
self
.
mainEquations
]
return
EquationCollection
(
newEq
,
[],
self
.
simplificationHint
s
)
return
self
.
copy
(
newEq
,
keptSubexpression
s
)
def
lambdify
(
self
,
symbols
,
module
=
None
,
fixedSymbols
=
{}):
"""
...
...
@@ -241,7 +256,7 @@ class EquationCollection:
:param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g.
'
numpy
'
:param fixedSymbols: dictionary with substitutions, that are applied before lambdification
"""
eqs
=
self
.
new
WithSubstitutionsApplied
(
fixedSymbols
).
insertSubexpressions
().
mainEquations
eqs
=
self
.
copy
WithSubstitutionsApplied
(
fixedSymbols
).
insertSubexpressions
().
mainEquations
lambdas
=
{
eq
.
lhs
:
sp
.
lambdify
(
symbols
,
eq
.
rhs
,
module
)
for
eq
in
eqs
}
def
f
(
*
args
,
**
kwargs
):
...
...
This diff is collapsed.
Click to expand it.
equationcollection/simplifications.py
+
7
−
12
View file @
bed12f75
import
sympy
as
sp
from
pystencils.equationcollection
import
EquationCollection
from
pystencils.sympyextensions
import
replaceAdditive
...
...
@@ -21,21 +20,18 @@ def sympyCSE(equationCollection):
topologicallySortedPairs
=
sp
.
cse_main
.
reps_toposort
([[
e
.
lhs
,
e
.
rhs
]
for
e
in
newSubexpressions
])
newSubexpressions
=
[
sp
.
Eq
(
a
[
0
],
a
[
1
])
for
a
in
topologicallySortedPairs
]
return
EquationCollection
(
modifiedUpdateEquations
,
newSubexpressions
,
equationCollection
.
simplificationHints
,
equationCollection
.
subexpressionSymbolNameGenerator
)
return
equationCollection
.
copy
(
modifiedUpdateEquations
,
newSubexpressions
)
def
applyOnAllEquations
(
equationCollection
,
operation
):
"""
Applies sympy expand operation to all equations in collection
"""
result
=
[
operation
(
s
)
for
s
in
equationCollection
.
mainEquations
]
return
equationCollection
.
newWithAdditionalSubexpressions
(
result
,
[]
)
return
equationCollection
.
copy
(
result
)
def
applyOnAllSubexpressions
(
equationCollection
,
operation
):
return
EquationCollection
(
equationCollection
.
mainEquations
,
[
operation
(
s
)
for
s
in
equationCollection
.
subexpressions
],
equationCollection
.
simplificationHints
,
equationCollection
.
subexpressionSymbolNameGenerator
)
return
equationCollection
.
copy
(
equationCollection
.
mainEquations
,
[
operation
(
s
)
for
s
in
equationCollection
.
subexpressions
])
def
subexpressionSubstitutionInExistingSubexpressions
(
equationCollection
):
...
...
@@ -49,8 +45,7 @@ def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
newRhs
=
newRhs
.
subs
(
subExpr
.
rhs
,
subExpr
.
lhs
)
result
.
append
(
sp
.
Eq
(
s
.
lhs
,
newRhs
))
return
EquationCollection
(
equationCollection
.
mainEquations
,
result
,
equationCollection
.
simplificationHints
,
equationCollection
.
subexpressionSymbolNameGenerator
)
return
equationCollection
.
copy
(
equationCollection
.
mainEquations
,
result
)
def
subexpressionSubstitutionInMainEquations
(
equationCollection
):
...
...
@@ -61,7 +56,7 @@ def subexpressionSubstitutionInMainEquations(equationCollection):
for
subExpr
in
equationCollection
.
subexpressions
:
newRhs
=
replaceAdditive
(
newRhs
,
subExpr
.
lhs
,
subExpr
.
rhs
,
requiredMatchReplacement
=
1.0
)
result
.
append
(
sp
.
Eq
(
s
.
lhs
,
newRhs
))
return
equationCollection
.
newWithAdditionalSubexpressions
(
result
,
[]
)
return
equationCollection
.
copy
(
result
)
def
addSubexpressionsForDivisions
(
equationCollection
):
...
...
@@ -80,4 +75,4 @@ def addSubexpressionsForDivisions(equationCollection):
newSymbolGen
=
equationCollection
.
subexpressionSymbolNameGenerator
substitutions
=
{
divisor
:
newSymbol
for
newSymbol
,
divisor
in
zip
(
newSymbolGen
,
divisors
)}
return
equationCollection
.
new
WithSubstitutionsApplied
(
substitutions
,
True
)
return
equationCollection
.
copy
WithSubstitutionsApplied
(
substitutions
,
True
)
This diff is collapsed.
Click to expand it.
sympyextensions.py
+
5
−
1
View file @
bed12f75
...
...
@@ -14,7 +14,11 @@ def fastSubs(term, subsDict):
return
expr
paramList
=
[
visit
(
a
)
for
a
in
expr
.
args
]
return
expr
if
not
paramList
else
expr
.
func
(
*
paramList
)
return
visit
(
term
)
if
len
(
subsDict
)
==
0
:
return
term
else
:
return
visit
(
term
)
def
replaceAdditive
(
expr
,
replacement
,
subExpression
,
requiredMatchReplacement
=
0.5
,
requiredMatchOriginal
=
None
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment