Skip to content
Snippets Groups Projects
Commit 0a29a147 authored by Martin Bauer's avatar Martin Bauer
Browse files

boundary generatlization

parent d6d843fb
No related branches found
No related tags found
No related merge requests found
...@@ -226,6 +226,21 @@ class EquationCollection: ...@@ -226,6 +226,21 @@ class EquationCollection:
allLhs = [eq.lhs for eq in self.mainEquations] allLhs = [eq.lhs for eq in self.mainEquations]
return self.extract(allLhs) return self.extract(allLhs)
def insertSubexpression(self, symbol):
newSubexpressions = []
subsDict = None
for se in self.subexpressions:
if se.lhs == symbol:
subsDict = {se.lhs: se.rhs}
else:
newSubexpressions.append(se)
if subsDict is None:
return self
newSubexpressions = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in newSubexpressions]
newEqs = [sp.Eq(eq.lhs, fastSubs(eq.rhs, subsDict)) for eq in self.mainEquations]
return self.copy(newEqs, newSubexpressions)
def insertSubexpressions(self, subexpressionSymbolsToKeep=set()): def insertSubexpressions(self, subexpressionSymbolsToKeep=set()):
"""Returns a new equation collection by inserting all subexpressions into the main equations""" """Returns a new equation collection by inserting all subexpressions into the main equations"""
if len(self.subexpressions) == 0: if len(self.subexpressions) == 0:
......
...@@ -25,13 +25,13 @@ def sympyCSE(equationCollection): ...@@ -25,13 +25,13 @@ def sympyCSE(equationCollection):
def applyOnAllEquations(equationCollection, operation): def applyOnAllEquations(equationCollection, operation):
"""Applies sympy expand operation to all equations in collection""" """Applies sympy expand operation to all equations in collection"""
result = [operation(s) for s in equationCollection.mainEquations] result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.mainEquations]
return equationCollection.copy(result) return equationCollection.copy(result)
def applyOnAllSubexpressions(equationCollection, operation): def applyOnAllSubexpressions(equationCollection, operation):
return equationCollection.copy(equationCollection.mainEquations, result = [sp.Eq(eq.lhs, operation(eq.rhs)) for eq in equationCollection.subexpressions]
[operation(s) for s in equationCollection.subexpressions]) return equationCollection.copy(equationCollection.mainEquations, result)
def subexpressionSubstitutionInExistingSubexpressions(equationCollection): def subexpressionSubstitutionInExistingSubexpressions(equationCollection):
...@@ -60,6 +60,8 @@ def subexpressionSubstitutionInMainEquations(equationCollection): ...@@ -60,6 +60,8 @@ def subexpressionSubstitutionInMainEquations(equationCollection):
def addSubexpressionsForDivisions(equationCollection): def addSubexpressionsForDivisions(equationCollection):
"""Introduces subexpressions for all divisions which have no constant in the denominator.
e.g. :math:`\frac{1}{x}` is replaced, :math:`\frac{1}{3}` is not replaced."""
divisors = set() divisors = set()
def searchDivisors(term): def searchDivisors(term):
......
...@@ -260,12 +260,33 @@ def extractMostCommonFactor(term): ...@@ -260,12 +260,33 @@ def extractMostCommonFactor(term):
coeffDict = term.as_coefficients_dict() coeffDict = term.as_coefficients_dict()
counter = Counter([Abs(v) for v in coeffDict.values()]) counter = Counter([Abs(v) for v in coeffDict.values()])
commonFactor, occurances = max(counter.items(), key=operator.itemgetter(1)) commonFactor, occurrences = max(counter.items(), key=operator.itemgetter(1))
if occurances == 1 and (1 in counter): if occurrences == 1 and (1 in counter):
commonFactor = 1 commonFactor = 1
return commonFactor, term / commonFactor return commonFactor, term / commonFactor
def mostCommonTermFactorization(term):
commonFactor, term = extractMostCommonFactor(term)
factorization = sp.factor(term)
if factorization.is_Mul:
symbolsInFactorization = []
constantsInFactorization = 1
for arg in factorization.args:
if len(arg.atoms(sp.Symbol)) == 0:
constantsInFactorization *= arg
else:
symbolsInFactorization.append(arg)
if len(symbolsInFactorization) <= 1:
return sp.Mul(commonFactor, term, evaluate=False)
else:
return sp.Mul(commonFactor, *symbolsInFactorization[:-1],
constantsInFactorization * symbolsInFactorization[-1])
else:
return sp.Mul(commonFactor, term, evaluate=False)
def countNumberOfOperations(term): def countNumberOfOperations(term):
""" """
Counts the number of additions, multiplications and division Counts the number of additions, multiplications and division
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment