# innerloopsplit.py
# Commit note (Martin Bauer): assignment collection - sympyextensions
import sympy as sp
from collections import defaultdict
from pystencils import Field
def createLbmSplitGroups(lbmCollisionEqs):
    """
    Creates split groups for LBM collision equations. For details about split groups see
    :func:`pystencils.transformation.splitInnerLoop` .
    The split groups are added as simplification hint 'splitGroups'

    Split groups are created in the following way: Opposing directions are put into a single group.
    The velocity subexpressions are pre-computed as well as all subexpressions which are used in all
    non-center collision equations, and depend on at least one pdf.

    The first split group contains, in this order: the pre-computed subexpression symbols
    (sorted by their string form so the result is reproducible between interpreter runs),
    every written field access that is not a post-collision pdf, and the left hand side of
    the center-direction equation. Each following group holds the left hand sides of one
    direction and of its opposite direction.

    Required simplification hints:
        - velocity: sequence of velocity symbols

    :param lbmCollisionEqs: collision equations; read through ``simplification_hints``,
                            ``method``, ``subexpressions``, ``main_assignments`` and
                            ``dependent_symbols``. The object is modified in place.
    :return: the ``lbmCollisionEqs`` object that was passed in, with the 'splitGroups' hint set
    """
    sh = lbmCollisionEqs.simplification_hints
    assert 'velocity' in sh, "Needs simplification hint 'velocity': Sequence of velocity symbols"

    method = lbmCollisionEqs.method
    stencil = method.stencil
    preCollisionSymbols = set(method.preCollisionPdfSymbols)
    postCollisionSymbols = set(method.postCollisionPdfSymbols)
    # NOTE(review): skipping index 0 presumes the first pdf belongs to the center direction - confirm
    nonCenterPostCollisionSymbols = set(method.postCollisionPdfSymbols[1:])

    # Start with all subexpressions that depend on at least one pre-collision pdf ...
    importantSubExpressions = {
        e.lhs for e in lbmCollisionEqs.subexpressions
        if preCollisionSymbols.intersection(lbmCollisionEqs.dependent_symbols([e.lhs]))
    }

    otherWrittenFields = []
    for eq in lbmCollisionEqs.main_assignments:
        if eq.lhs not in postCollisionSymbols and isinstance(eq.lhs, Field.Access):
            otherWrittenFields.append(eq.lhs)
        if eq.lhs in nonCenterPostCollisionSymbols:
            # ... and keep only those that occur in every non-center collision equation
            importantSubExpressions.intersection_update(eq.rhs.atoms(sp.Symbol))

    importantSubExpressions.update(sh['velocity'])

    # Set iteration order depends on hash values, which vary between interpreter runs for
    # string-based hashes; sorting by the printed form keeps the first group reproducible.
    subexpressionsToPreCompute = sorted(importantSubExpressions, key=str)
    splitGroups = [subexpressionsToPreCompute + otherWrittenFields]

    directionGroups = defaultdict(list)
    centerDirection = (0,) * len(stencil[0])

    # zip pairs the i-th stencil direction with the i-th main assignment and stops at the shorter one
    for direction, eq in zip(stencil, lbmCollisionEqs.main_assignments):
        if direction == centerDirection:
            splitGroups[0].append(eq.lhs)
            continue

        inverseDir = tuple(-i for i in direction)
        # join the group of the opposite direction if that one was seen first
        groupKey = inverseDir if inverseDir in directionGroups else direction
        directionGroups[groupKey].append(eq.lhs)

    splitGroups.extend(directionGroups.values())

    lbmCollisionEqs.simplification_hints['splitGroups'] = splitGroups
    return lbmCollisionEqs