Skip to content
Snippets Groups Projects
Commit 181e9b17 authored by Martin Bauer's avatar Martin Bauer
Browse files

Code Quality

- switched to google style docstrings
- removed dead code
- started to annotate types
parent 66322279
Branches
Tags
No related merge requests found
from lbmpy.chapman_enskog.derivative import DiffOperator, Diff, expandUsingLinearity, expandUsingProductRule, \ from lbmpy.chapman_enskog.derivative import DiffOperator, Diff, expandUsingLinearity, expandUsingProductRule, \
normalizeDiffOrder, chapmanEnskogDerivativeExpansion, chapmanEnskogDerivativeRecombination normalizeDiffOrder, chapmanEnskogDerivativeExpansion, chapmanEnskogDerivativeRecombination
from lbmpy.chapman_enskog.chapman_enskog import getExpandedName, LbMethodEqMoments, insertMoments, takeMoments, \ from lbmpy.chapman_enskog.chapman_enskog import LbMethodEqMoments, insertMoments, takeMoments, \
CeMoment, chainSolveAndSubstitute, timeDiffSelector, momentSelector, ChapmanEnskogAnalysis CeMoment, chainSolveAndSubstitute, timeDiffSelector, momentSelector, ChapmanEnskogAnalysis
...@@ -8,7 +8,8 @@ from lbmpy.chapman_enskog import Diff, expandUsingLinearity, expandUsingProductR ...@@ -8,7 +8,8 @@ from lbmpy.chapman_enskog import Diff, expandUsingLinearity, expandUsingProductR
from lbmpy.chapman_enskog import DiffOperator, normalizeDiffOrder, chapmanEnskogDerivativeExpansion, \ from lbmpy.chapman_enskog import DiffOperator, normalizeDiffOrder, chapmanEnskogDerivativeExpansion, \
chapmanEnskogDerivativeRecombination chapmanEnskogDerivativeRecombination
from lbmpy.chapman_enskog.derivative import collectDerivatives, createNestedDiff from lbmpy.chapman_enskog.derivative import collectDerivatives, createNestedDiff
from lbmpy.moments import discreteMoment, momentMatrix, polynomialToExponentRepresentation, getMomentIndices from lbmpy.moments import discreteMoment, momentMatrix, polynomialToExponentRepresentation, getMomentIndices, \
nonAliasedMoment
from pystencils.cache import diskcache from pystencils.cache import diskcache
from pystencils.sympyextensions import normalizeProduct, multidimensionalSummation, kroneckerDelta from pystencils.sympyextensions import normalizeProduct, multidimensionalSummation, kroneckerDelta
from pystencils.sympyextensions import productSymmetric from pystencils.sympyextensions import productSymmetric
...@@ -17,14 +18,6 @@ from pystencils.sympyextensions import productSymmetric ...@@ -17,14 +18,6 @@ from pystencils.sympyextensions import productSymmetric
# --------------------------------------------- Helper Functions ------------------------------------------------------- # --------------------------------------------- Helper Functions -------------------------------------------------------
def getExpandedName(originalObject, number):
import warnings
warnings.warn("Deprecated!")
name = originalObject.name
newName = name + "^{(%i)}" % (number,)
return originalObject.func(newName)
def expandedSymbol(name, subscript=None, superscript=None, **kwargs): def expandedSymbol(name, subscript=None, superscript=None, **kwargs):
if subscript is not None: if subscript is not None:
name += "_{%s}" % (subscript,) name += "_{%s}" % (subscript,)
...@@ -32,49 +25,10 @@ def expandedSymbol(name, subscript=None, superscript=None, **kwargs): ...@@ -32,49 +25,10 @@ def expandedSymbol(name, subscript=None, superscript=None, **kwargs):
name += "^{(%s)}" % (superscript,) name += "^{(%s)}" % (superscript,)
return sp.Symbol(name, **kwargs) return sp.Symbol(name, **kwargs)
# -------------------------------- Summation Convention -------------------------------------------------------------
def getOccurrenceCountOfIndex(term, index):
if isinstance(term, Diff):
return getOccurrenceCountOfIndex(term.arg, index) + (1 if term.target == index else 0)
elif isinstance(term, sp.Symbol):
return 1 if term.name.endswith("_" + str(index)) else 0
else:
return 0
def replaceIndex(term, oldIndex, newIndex):
if isinstance(term, Diff):
newArg = replaceIndex(term.arg, oldIndex, newIndex)
newLabel = newIndex if term.target == oldIndex else term.target
return Diff(newArg, newLabel, term.superscript)
elif isinstance(term, sp.Symbol):
if term.name.endswith("_" + str(oldIndex)):
baseName = term.name[:-(len(str(oldIndex))+1)]
return sp.Symbol(baseName + "_" + str(newIndex))
else:
return term
else:
newArgs = [replaceIndex(a, oldIndex, newIndex) for a in term.args]
return term.func(*newArgs) if newArgs else term
# Problem: when there are more than two repeated indices... which one to replace?
# ---------------------------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------------------------
def momentAliasing(momentTuple):
moment = list(momentTuple)
result = []
for element in moment:
if element > 2:
result.append(2 - (element % 2))
else:
result.append(element)
return tuple(result)
class CeMoment(sp.Symbol): class CeMoment(sp.Symbol):
def __new__(cls, name, *args, **kwds): def __new__(cls, name, *args, **kwds):
obj = CeMoment.__xnew_cached_(cls, name, *args, **kwds) obj = CeMoment.__xnew_cached_(cls, name, *args, **kwds)
...@@ -256,7 +210,7 @@ def takeMoments(eqn, pdfToMomentName=(('f', '\Pi'), ('\Omega f', '\\Upsilon')), ...@@ -256,7 +210,7 @@ def takeMoments(eqn, pdfToMomentName=(('f', '\Pi'), ('\Omega f', '\\Upsilon')),
momentTuple = tuple(momentTuple) momentTuple = tuple(momentTuple)
if useOneNeighborhoodAliasing: if useOneNeighborhoodAliasing:
momentTuple = momentAliasing(momentTuple) momentTuple = nonAliasedMoment(momentTuple)
result = CeMoment(fIndex.momentName, momentTuple, fIndex.superscript) result = CeMoment(fIndex.momentName, momentTuple, fIndex.superscript)
if derivativeTerm is not None: if derivativeTerm is not None:
result = derivativeTerm.changeArgRecursive(result) result = derivativeTerm.changeArgRecursive(result)
......
# -*- coding: utf-8 -*-
""" """
Module Overview Module Overview
~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~
...@@ -25,9 +25,9 @@ All moment polynomials have to use ``MOMENT_SYMBOLS`` (which is a module variabl ...@@ -25,9 +25,9 @@ All moment polynomials have to use ``MOMENT_SYMBOLS`` (which is a module variabl
Example :: Example ::
from lbmpy.moments import MOMENT_SYMBOLS >>> from lbmpy.moments import MOMENT_SYMBOLS
x, y, z = MOMENT_SYMBOLS >>> x, y, z = MOMENT_SYMBOLS
secondOrderMoment = x*y + y*z >>> secondOrderMoment = x*y + y*z
Functions Functions
...@@ -36,65 +36,17 @@ Functions ...@@ -36,65 +36,17 @@ Functions
""" """
import itertools import itertools
import math import math
from collections import Counter, defaultdict
from copy import copy from copy import copy
from collections import Counter, defaultdict
from typing import Iterable, List, Sequence, Tuple, TypeVar, Optional
import sympy as sp import sympy as sp
from pystencils.cache import memorycache from pystencils.cache import memorycache
from pystencils.sympyextensions import removeHigherOrderTerms from pystencils.sympyextensions import removeHigherOrderTerms
MOMENT_SYMBOLS = sp.symbols("x y z")
MOMENT_SYMBOLS = sp.symbols('x y z')
def __uniqueList(seq): T = TypeVar('T')
seen = {}
result = []
for item in seq:
if item in seen:
continue
seen[item] = 1
result.append(item)
return result
def __uniquePermutations(elements):
if len(elements) == 1:
yield (elements[0],)
else:
unique_elements = set(elements)
for first_element in unique_elements:
remaining_elements = list(elements)
remaining_elements.remove(first_element)
for sub_permutation in __uniquePermutations(remaining_elements):
yield (first_element,) + sub_permutation
def __generateFixedSumTuples(tupleLength, tupleSum, allowedValues=None, ordered=False):
if not allowedValues:
allowedValues = list(range(0, tupleSum + 1))
assert (0 in allowedValues)
def recursive_helper(currentList, position, totalSum):
newPosition = position + 1
if newPosition < len(currentList):
for i in allowedValues:
currentList[position] = i
newSum = totalSum - i
if newSum < 0:
continue
for item in recursive_helper(currentList, newPosition, newSum):
yield item
else:
if totalSum in allowedValues:
currentList[-1] = totalSum
if not ordered:
yield tuple(currentList)
if ordered and currentList == sorted(currentList, reverse=True):
yield tuple(currentList)
return recursive_helper([0] * tupleLength, 0, tupleSum)
# ------------------------------ Discrete (Exponent Tuples) ------------------------------------------------------------ # ------------------------------ Discrete (Exponent Tuples) ------------------------------------------------------------
...@@ -128,12 +80,12 @@ def pickRepresentativeMoments(moments): ...@@ -128,12 +80,12 @@ def pickRepresentativeMoments(moments):
def momentPermutations(exponentTuple): def momentPermutations(exponentTuple):
"""Returns all (unique) permutations of the given tuple""" """Returns all (unique) permutations of the given tuple"""
return __uniquePermutations(exponentTuple) return __unique_permutations(exponentTuple)
def momentsOfOrder(order, dim=3, includePermutations=True): def momentsOfOrder(order, dim=3, includePermutations=True):
"""All tuples of length 'dim' which sum equals 'order'""" """All tuples of length 'dim' which sum equals 'order'"""
for item in __generateFixedSumTuples(dim, order, ordered=not includePermutations): for item in __fixed_sum_tuples(dim, order, ordered=not includePermutations):
assert(len(item) == dim) assert(len(item) == dim)
assert(sum(item) == order) assert(sum(item) == order)
yield item yield item
...@@ -155,7 +107,7 @@ def extendMomentsWithPermutations(exponentTuples): ...@@ -155,7 +107,7 @@ def extendMomentsWithPermutations(exponentTuples):
allMoments = [] allMoments = []
for i in exponentTuples: for i in exponentTuples:
allMoments += list(momentPermutations(i)) allMoments += list(momentPermutations(i))
return __uniqueList(allMoments) return __unique(allMoments)
# ------------------------------ Representation Conversions ------------------------------------------------------------ # ------------------------------ Representation Conversions ------------------------------------------------------------
...@@ -278,8 +230,7 @@ def getExponentTupleFromIndices(momentIndices, dim): ...@@ -278,8 +230,7 @@ def getExponentTupleFromIndices(momentIndices, dim):
def getOrder(moment): def getOrder(moment):
""" """Computes polynomial order of given moment.
Computes polynomial order of given moment
Examples: Examples:
>>> x , y, z = MOMENT_SYMBOLS >>> x , y, z = MOMENT_SYMBOLS
...@@ -294,9 +245,32 @@ def getOrder(moment): ...@@ -294,9 +245,32 @@ def getOrder(moment):
return sum(moment) return sum(moment)
if len(moment.atoms(sp.Symbol)) == 0: if len(moment.atoms(sp.Symbol)) == 0:
return 0 return 0
leadingCoefficient = sp.polys.polytools.LM(moment) leading_coefficient = sp.polys.polytools.LM(moment)
symbolsInLeadingCoefficient = leadingCoefficient.atoms(sp.Symbol) symbols_in_leading_coefficient = leading_coefficient.atoms(sp.Symbol)
return sum([sp.degree(leadingCoefficient, gen=m) for m in symbolsInLeadingCoefficient]) return sum([sp.degree(leading_coefficient, gen=m) for m in symbols_in_leading_coefficient])
def nonAliasedMoment(moment_tuple: Sequence[int]) -> Tuple[int, ...]:
"""Takes a moment exponent tuple and returns the non-aliased version of it.
For first neighborhood stencils, all moments with exponents 3, 5, 7... are equal to exponent 1,
and exponents 4, 6, 8... are equal to exponent 2. This is because first neighborhood stencils only have values
d ∈ {+1, 0, -1}. So for example d**5 is always the same as d**3 and d, and d**6 == d**4 == d**2
Example:
>>> nonAliasedMoment((5, 4, 2))
(1, 2, 2)
>>> nonAliasedMoment((9, 1, 2))
(1, 1, 2)
"""
moment = list(moment_tuple)
result = []
for element in moment:
if element > 2:
result.append(2 - (element % 2))
else:
result.append(element)
return tuple(result)
def isShearMoment(moment): def isShearMoment(moment):
...@@ -605,3 +579,93 @@ def monomialToPolynomialTransformationMatrix(monomials, polynomials): ...@@ -605,3 +579,93 @@ def monomialToPolynomialTransformationMatrix(monomials, polynomials):
exponentTuple = exponentTuple[:dim] exponentTuple = exponentTuple[:dim]
result[polynomialIdx, monomials.index(exponentTuple)] = factor result[polynomialIdx, monomials.index(exponentTuple)] = factor
return result return result
# --------------------------------------- Internal Functions -----------------------------------------------------------
def __unique(seq: Sequence[T]) -> List[T]:
"""Removes duplicates from a sequence in an order preserving way.
Example:
>>> __unique((1, 2, 3, 1, 4, 6, 3))
[1, 2, 3, 4, 6]
"""
seen = {}
result = []
for item in seq:
if item in seen:
continue
seen[item] = 1
result.append(item)
return result
def __unique_permutations(elements: Sequence[T]) -> Iterable[T]:
"""Generates all unique permutations of the passed sequence.
Example:
>>> list(__unique_permutations([1, 1, 2]))
[(1, 1, 2), (1, 2, 1), (2, 1, 1)]
"""
if len(elements) == 1:
yield (elements[0],)
else:
unique_elements = set(elements)
for first_element in unique_elements:
remaining_elements = list(elements)
remaining_elements.remove(first_element)
for sub_permutation in __unique_permutations(remaining_elements):
yield (first_element,) + sub_permutation
def __fixed_sum_tuples(tuple_length: int, tuple_sum: int,
allowed_values: Optional[Sequence[int]] = None,
ordered: bool = False) -> Iterable[Tuple[int, ...]]:
"""Generates all possible tuples of positive integers with a fixed sum of all entries.
Args:
tuple_length: length of the returned tuples
tuple_sum: summing over the entries of a yielded tuple should give this number
allowed_values: optional sequence of positive integers that are considered as tuple entries
zero has to be in the set of allowed values
if None, all possible positive integers are allowed
ordered: if True, only tuples are returned where the entries are descending, i.e. where the entries are ordered
Examples:
Generate all 2-tuples where the sum of entries is 3
>>> list(__fixed_sum_tuples(tuple_length=2, tuple_sum=3))
[(0, 3), (1, 2), (2, 1), (3, 0)]
Same with ordered tuples only
>>> list(__fixed_sum_tuples(tuple_length=2, tuple_sum=3, ordered=True))
[(2, 1), (3, 0)]
Restricting the allowed values, note that zero has to be in the allowed values!
>>> list(__fixed_sum_tuples(tuple_length=3, tuple_sum=4, allowed_values=(0, 1, 3)))
[(0, 1, 3), (0, 3, 1), (1, 0, 3), (1, 3, 0), (3, 0, 1), (3, 1, 0)]
"""
if not allowed_values:
allowed_values = set(range(0, tuple_sum + 1))
assert 0 in allowed_values and all(i >= 0 for i in allowed_values)
def recursive_helper(current_list, position, total_sum):
new_position = position + 1
if new_position < len(current_list):
for i in allowed_values:
current_list[position] = i
new_sum = total_sum - i
if new_sum < 0:
continue
for item in recursive_helper(current_list, new_position, new_sum):
yield item
else:
if total_sum in allowed_values:
current_list[-1] = total_sum
if not ordered:
yield tuple(current_list)
if ordered and current_list == sorted(current_list, reverse=True):
yield tuple(current_list)
return recursive_helper([0] * tuple_length, 0, tuple_sum)
\ No newline at end of file
import sympy as sp import sympy as sp
import numpy as np import numpy as np
from pystencils.jupytersetup import *
from lbmpy.scenarios import * from lbmpy.scenarios import *
from lbmpy.creationfunctions import * from lbmpy.creationfunctions import *
from pystencils import makeSlice, showCode from pystencils import makeSlice, showCode
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment