# Source: innerloopsplit.py (2.67 KiB)
import sympy as sp
from collections import defaultdict
from pystencils import Field
from lbmpy.methods.abstractlbmethod import LbmCollisionRule


def create_lbm_split_groups(cr: LbmCollisionRule, opposing_directions=True):
    """
    Attach the simplification hint 'split_groups' to an LBM collision rule.

    Split groups are explained in :func:`pystencils.transformation.split_inner_loop` .
    The hint is a list of symbol lists, built as follows:

    - The first group holds the symbols to pre-compute: the velocity symbols from the
      'velocity' hint, plus every subexpression that depends on at least one
      pre-collision pdf and occurs in the right-hand side of every non-center
      post-collision assignment. Field accesses written by main assignments that are
      not post-collision pdfs are appended to this group as well.
    - If ``opposing_directions`` is true, stencil directions are matched with the main
      assignments by position; the center store joins the first group, and each
      remaining store shares a group with the store of its opposite direction.
    - Otherwise every main assignment gets a group of its own.

    Required simplification hints:
        - velocity: sequence of velocity symbols

    Returns:
        The collision rule ``cr`` that was passed in, with the hint added in place.
    """
    hints = cr.simplification_hints
    assert 'velocity' in hints, "Needs simplification hint 'velocity': Sequence of velocity symbols"

    method = cr.method
    pdfs_before = set(method.pre_collision_pdf_symbols)
    # Index 0 is skipped for the "non-center" set.
    pdfs_after_no_center = set(method.post_collision_pdf_symbols[1:])
    pdfs_after = set(method.post_collision_pdf_symbols)
    stencil = method.stencil

    # Start from all subexpressions that (transitively) depend on some pre-collision pdf.
    shared_symbols = set()
    for sub in cr.subexpressions:
        if not pdfs_before.isdisjoint(cr.dependent_symbols([sub.lhs])):
            shared_symbols.add(sub.lhs)

    extra_field_writes = []
    for assignment in cr.main_assignments:
        target = assignment.lhs
        if target not in pdfs_after and isinstance(target, Field.Access):
            extra_field_writes.append(target)
        if target in pdfs_after_no_center:
            # Keep only symbols that every non-center update uses directly.
            shared_symbols.intersection_update(assignment.rhs.atoms(sp.Symbol))

    shared_symbols.update(hints['velocity'])

    split_groups = [list(shared_symbols) + extra_field_writes]

    dim = len(stencil[0])

    if not opposing_directions:
        split_groups.extend([assignment.lhs] for assignment in cr.main_assignments)
    else:
        center = (0,) * dim
        paired_stores = defaultdict(list)
        for direction, assignment in zip(stencil, cr.main_assignments):
            if direction == center:
                split_groups[0].append(assignment.lhs)
                continue

            mirrored = tuple(-component for component in direction)
            # Join the group of the opposite direction if it was seen first.
            group_key = mirrored if mirrored in paired_stores else direction
            paired_stores[group_key].append(assignment.lhs)
        split_groups.extend(paired_stores.values())

    cr.simplification_hints['split_groups'] = split_groups
    return cr