Skip to content
Snippets Groups Projects
Commit bed12f75 authored by Martin Bauer's avatar Martin Bauer
Browse files

pystencils: generalized equationcollection

parent 69ec4168
No related branches found
No related tags found
No related merge requests found
import sympy as sp
from copy import copy, deepcopy
from pystencils.sympyextensions import fastSubs, countNumberOfOperations
......@@ -20,52 +21,50 @@ class EquationCollection:
# ----------------------------------------- Creation ---------------------------------------------------------------
def __init__(self, equations, subExpressions, simplificationHints={}, subexpressionSymbolNameGenerator=None):
def __init__(self, equations, subExpressions, simplificationHints=None, subexpressionSymbolNameGenerator=None):
self.mainEquations = equations
self.subexpressions = subExpressions
if simplificationHints is None:
simplificationHints = {}
self.simplificationHints = simplificationHints
def symbolGen():
"""Use this generator to create new unused symbols for subexpressions"""
counter = 0
while True:
counter += 1
newSymbol = sp.Symbol("xi_" + str(counter))
if newSymbol in self.boundSymbols:
continue
yield newSymbol
class SymbolGen:
def __init__(self):
self._ctr = 0
def __iter__(self):
return self
def __next__(self):
self._ctr += 1
return sp.Symbol("xi_" + str(self._ctr))
if subexpressionSymbolNameGenerator is None:
self.subexpressionSymbolNameGenerator = symbolGen()
self.subexpressionSymbolNameGenerator = SymbolGen()
else:
self.subexpressionSymbolNameGenerator = subexpressionSymbolNameGenerator
def newWithAdditionalSubexpressions(self, newEquations, additionalSubExpressions):
"""
Returns a new equation collection, that has `newEquations` as mainEquations.
The `additionalSubExpressions` are appended to the existing subexpressions.
Simplifications hints are copied over.
"""
assert len(self.mainEquations) == len(newEquations), "Number of update equations cannot be changed"
res = EquationCollection(newEquations,
self.subexpressions + additionalSubExpressions,
self.simplificationHints)
res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator
def copy(self, mainEquations=None, subexpressions=None):
res = deepcopy(self)
if mainEquations is not None:
res.mainEquations = mainEquations
if subexpressions is not None:
res.subexpressions = subexpressions
return res
def newWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpresions=False):
def copyWithSubstitutionsApplied(self, substitutionDict, addSubstitutionsAsSubexpressions=False):
"""
Returns a new equation collection, where terms are substituted according to the passed `substitutionDict`.
Substitutions are made in the subexpression terms and the main equations
"""
newSubexpressions = [fastSubs(eq, substitutionDict) for eq in self.subexpressions]
newEquations = [fastSubs(eq, substitutionDict) for eq in self.mainEquations]
if addSubstitutionsAsSubexpresions:
if addSubstitutionsAsSubexpressions:
newSubexpressions = [sp.Eq(b, a) for a, b in substitutionDict.items()] + newSubexpressions
res = EquationCollection(newEquations, newSubexpressions, self.simplificationHints)
res.subexpressionSymbolNameGenerator = self.subexpressionSymbolNameGenerator
return res
return self.copy(newEquations, newSubexpressions)
def addSimplificationHint(self, key, value):
"""
......@@ -178,41 +177,45 @@ class EquationCollection:
substitutionDict[otherSubexpressionEq.lhs] = newLhs
else:
processedOtherSubexpressionEquations.append(fastSubs(otherSubexpressionEq, substitutionDict))
return EquationCollection(self.mainEquations + other.mainEquations,
self.subexpressions + processedOtherSubexpressionEquations)
return self.copy(self.mainEquations + other.mainEquations,
self.subexpressions + processedOtherSubexpressionEquations)
def extract(self, symbolsToExtract):
"""
Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
only the necessary subexpressions that are used in these equations
"""
symbolsToExtract = set(symbolsToExtract)
newEquations = []
def getDependentSymbols(self, symbolSequence):
"""Returns a list of symbols that depend on the passed symbols."""
subexprMap = {e.lhs: e.rhs for e in self.subexpressions}
handledSymbols = set()
queue = []
queue = list(symbolSequence)
def addSymbolsFromExpr(expr):
dependentSymbols = expr.atoms(sp.Symbol)
for ds in dependentSymbols:
if ds not in handledSymbols:
queue.append(ds)
handledSymbols.add(ds)
queue.append(ds)
for eq in self.allEquations:
if eq.lhs in symbolsToExtract:
newEquations.append(eq)
addSymbolsFromExpr(eq.rhs)
handledSymbols = set()
eqMap = {e.lhs: e.rhs for e in self.allEquations}
while len(queue) > 0:
e = queue.pop(0)
if e not in subexprMap:
if e in handledSymbols:
continue
else:
addSymbolsFromExpr(subexprMap[e])
if e in eqMap:
addSymbolsFromExpr(eqMap[e])
handledSymbols.add(e)
return handledSymbols
def extract(self, symbolsToExtract):
"""
Creates a new equation collection with equations that have symbolsToExtract as left-hand-sides and
only the necessary subexpressions that are used in these equations
"""
symbolsToExtract = set(symbolsToExtract)
dependentSymbols = self.getDependentSymbols(symbolsToExtract)
newEquations = []
for eq in self.allEquations:
if eq.lhs in symbolsToExtract:
newEquations.append(eq)
newSubExpr = [eq for eq in self.subexpressions if eq.lhs in handledSymbols and eq.lhs not in symbolsToExtract]
newSubExpr = [eq for eq in self.subexpressions if eq.lhs in dependentSymbols and eq.lhs not in symbolsToExtract]
return EquationCollection(newEquations, newSubExpr)
def newWithoutUnusedSubexpressions(self):
......@@ -221,18 +224,30 @@ class EquationCollection:
allLhs = [eq.lhs for eq in self.mainEquations]
return self.extract(allLhs)
def insertSubexpressions(self):
def insertSubexpressions(self, subexpressionSymbolsToKeep=set()):
"""Returns a new equation collection by inserting all subexpressions into the main equations"""
if len(self.subexpressions) == 0:
return EquationCollection(self.mainEquations, self.subexpressions, self.simplificationHints)
subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
return self.copy()
subexpressionSymbolsToKeep = set(subexpressionSymbolsToKeep)
keptSubexpressions = []
if self.subexpressions[0].lhs in subexpressionSymbolsToKeep:
subsDict = {}
keptSubexpressions = self.subexpressions[0]
else:
subsDict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
subExpr = [e for e in self.subexpressions]
for i in range(1, len(subExpr)):
subExpr[i] = fastSubs(subExpr[i], subsDict)
subsDict[subExpr[i].lhs] = subExpr[i].rhs
if subExpr[i].lhs in subexpressionSymbolsToKeep:
keptSubexpressions.append(subExpr[i])
else:
subsDict[subExpr[i].lhs] = subExpr[i].rhs
newEq = [fastSubs(eq, subsDict) for eq in self.mainEquations]
return EquationCollection(newEq, [], self.simplificationHints)
return self.copy(newEq, keptSubexpressions)
def lambdify(self, symbols, module=None, fixedSymbols={}):
"""
......@@ -241,7 +256,7 @@ class EquationCollection:
:param module: same as sympy.lambdify paramter of same same, i.e. which module to use e.g. 'numpy'
:param fixedSymbols: dictionary with substitutions, that are applied before lambdification
"""
eqs = self.newWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
eqs = self.copyWithSubstitutionsApplied(fixedSymbols).insertSubexpressions().mainEquations
lambdas = {eq.lhs: sp.lambdify(symbols, eq.rhs, module) for eq in eqs}
def f(*args, **kwargs):
......
import sympy as sp
from pystencils.equationcollection import EquationCollection
from pystencils.sympyextensions import replaceAdditive
......@@ -21,21 +20,18 @@ def sympyCSE(equationCollection):
topologicallySortedPairs = sp.cse_main.reps_toposort([[e.lhs, e.rhs] for e in newSubexpressions])
newSubexpressions = [sp.Eq(a[0], a[1]) for a in topologicallySortedPairs]
return EquationCollection(modifiedUpdateEquations, newSubexpressions, equationCollection.simplificationHints,
equationCollection.subexpressionSymbolNameGenerator)
return equationCollection.copy(modifiedUpdateEquations, newSubexpressions)
def applyOnAllEquations(equationCollection, operation):
"""Applies sympy expand operation to all equations in collection"""
result = [operation(s) for s in equationCollection.mainEquations]
return equationCollection.newWithAdditionalSubexpressions(result, [])
return equationCollection.copy(result)
def applyOnAllSubexpressions(equationCollection, operation):
return EquationCollection(equationCollection.mainEquations,
[operation(s) for s in equationCollection.subexpressions],
equationCollection.simplificationHints,
equationCollection.subexpressionSymbolNameGenerator)
return equationCollection.copy(equationCollection.mainEquations,
[operation(s) for s in equationCollection.subexpressions])
def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
......@@ -49,8 +45,7 @@ def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
newRhs = newRhs.subs(subExpr.rhs, subExpr.lhs)
result.append(sp.Eq(s.lhs, newRhs))
return EquationCollection(equationCollection.mainEquations, result, equationCollection.simplificationHints,
equationCollection.subexpressionSymbolNameGenerator)
return equationCollection.copy(equationCollection.mainEquations, result)
def subexpressionSubstitutionInMainEquations(equationCollection):
......@@ -61,7 +56,7 @@ def subexpressionSubstitutionInMainEquations(equationCollection):
for subExpr in equationCollection.subexpressions:
newRhs = replaceAdditive(newRhs, subExpr.lhs, subExpr.rhs, requiredMatchReplacement=1.0)
result.append(sp.Eq(s.lhs, newRhs))
return equationCollection.newWithAdditionalSubexpressions(result, [])
return equationCollection.copy(result)
def addSubexpressionsForDivisions(equationCollection):
......@@ -80,4 +75,4 @@ def addSubexpressionsForDivisions(equationCollection):
newSymbolGen = equationCollection.subexpressionSymbolNameGenerator
substitutions = {divisor: newSymbol for newSymbol, divisor in zip(newSymbolGen, divisors)}
return equationCollection.newWithSubstitutionsApplied(substitutions, True)
return equationCollection.copyWithSubstitutionsApplied(substitutions, True)
......@@ -14,7 +14,11 @@ def fastSubs(term, subsDict):
return expr
paramList = [visit(a) for a in expr.args]
return expr if not paramList else expr.func(*paramList)
return visit(term)
if len(subsDict) == 0:
return term
else:
return visit(term)
def replaceAdditive(expr, replacement, subExpression, requiredMatchReplacement=0.5, requiredMatchOriginal=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment