From 81c6c262e7b1bff845eb6b08a40e797addfba5ac Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Tue, 14 Jan 2020 10:30:01 +0100
Subject: [PATCH] lbmpy: Simplifications change

- also for default MRT methods its better to use symbolic RR's instead of inserting numeric values for them
- but then the simplification strategy developed for SRT/TRT has to be disabled
- this commit introduces two new options:
     - simplification='auto' or a custom SimplificationStrategy
     - 'keep_rrs_symbolic'=True, by default now RRs are left symbolic
---
 lbmpy/creationfunctions.py                  | 16 +++++--
 lbmpy/methods/momentbased.py                | 11 ++---
 lbmpy/simplificationfactory.py              | 51 ++++++++++-----------
 lbmpy_tests/test_srt_trt_simplifications.py |  7 +--
 4 files changed, 44 insertions(+), 41 deletions(-)

diff --git a/lbmpy/creationfunctions.py b/lbmpy/creationfunctions.py
index 477bea98..57b95331 100644
--- a/lbmpy/creationfunctions.py
+++ b/lbmpy/creationfunctions.py
@@ -198,6 +198,7 @@ from pystencils import Assignment, AssignmentCollection, create_kernel
 from pystencils.cache import disk_cache_no_fallback
 from pystencils.data_types import collate_types
 from pystencils.field import Field, get_layout_of_array
+from pystencils.simp import sympy_cse
 from pystencils.stencil import have_same_entries
 
 
@@ -300,8 +301,6 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs):
         lb_method = create_lb_method(**params)
 
     split_inner_loop = 'split' in opt_params and opt_params['split']
-    simplification = create_simplification_strategy(lb_method, cse_pdfs=False, cse_global=False,
-                                                    split_inner_loop=split_inner_loop)
     cqc = lb_method.conserved_quantity_computation
 
     rho_in = params['density_input']
@@ -312,17 +311,23 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs):
     if rho_in is not None and isinstance(rho_in, Field):
         rho_in = rho_in.center
 
+    keep_rrs_symbolic = opt_params['keep_rrs_symbolic']
     if u_in is not None:
         density_rhs = sum(lb_method.pre_collision_pdf_symbols) if rho_in is None else rho_in
         eqs = [Assignment(cqc.zeroth_order_moment_symbol, density_rhs)]
         eqs += [Assignment(u_sym, u_in[i]) for i, u_sym in enumerate(cqc.first_order_moment_symbols)]
         eqs = AssignmentCollection(eqs, [])
-        collision_rule = lb_method.get_collision_rule(conserved_quantity_equations=eqs)
+        collision_rule = lb_method.get_collision_rule(conserved_quantity_equations=eqs,
+                                                      keep_rrs_symbolic=keep_rrs_symbolic)
     elif u_in is None and rho_in is not None:
         raise ValueError("When setting 'density_input' parameter, 'velocity_input' has to be specified as well.")
     else:
-        collision_rule = lb_method.get_collision_rule()
+        collision_rule = lb_method.get_collision_rule(keep_rrs_symbolic=keep_rrs_symbolic)
 
+    if opt_params['simplification'] == 'auto':
+        simplification = create_simplification_strategy(lb_method, split_inner_loop=split_inner_loop)
+    else:
+        simplification = opt_params['simplification']
     collision_rule = simplification(collision_rule)
 
     if params['fluctuating']:
@@ -353,7 +358,6 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs):
         from lbmpy.methods.momentbasedsimplifications import cse_in_opposing_directions
         collision_rule = cse_in_opposing_directions(collision_rule)
     if cse_global:
-        from pystencils.simp import sympy_cse
         collision_rule = sympy_cse(collision_rule)
 
     if params['output'] and params['kernel_type'] == 'stream_pull_collide':
@@ -550,6 +554,8 @@ def update_with_default_parameters(params, opt_params=None, fail_on_unknown_para
     default_optimization_description = {
         'cse_pdfs': False,
         'cse_global': False,
+        'simplification': 'auto',
+        'keep_rrs_symbolic': True,
         'split': False,
 
         'field_size': None,
diff --git a/lbmpy/methods/momentbased.py b/lbmpy/methods/momentbased.py
index 50f0ef87..dfde11c5 100644
--- a/lbmpy/methods/momentbased.py
+++ b/lbmpy/methods/momentbased.py
@@ -88,9 +88,9 @@ class MomentBasedLbMethod(AbstractLbMethod):
         equilibrium = self.get_equilibrium()
         return sp.Matrix([eq.rhs for eq in equilibrium.main_assignments])
 
-    def get_collision_rule(self, conserved_quantity_equations=None):
+    def get_collision_rule(self, conserved_quantity_equations=None, keep_rrs_symbolic=True):
         d = sp.diag(*self.relaxation_rates)
-        relaxation_rate_sub_expressions, d = self._generate_relaxation_matrix(d)
+        relaxation_rate_sub_expressions, d = self._generate_relaxation_matrix(d, keep_rrs_symbolic)
         ac = self._collision_rule_with_relaxation_matrix(d, relaxation_rate_sub_expressions,
                                                          True, conserved_quantity_equations)
         return ac
@@ -216,16 +216,15 @@ class MomentBasedLbMethod(AbstractLbMethod):
                                 simplification_hints)
 
     @staticmethod
-    def _generate_relaxation_matrix(relaxation_matrix):
+    def _generate_relaxation_matrix(relaxation_matrix, keep_rr_symbolic):
         """
         For SRT and TRT the equations can be easier simplified if the relaxation times are symbols, not numbers.
         This function replaces the numbers in the relaxation matrix with symbols in this case, and returns also
          the subexpressions, that assign the number to the newly introduced symbol
         """
         rr = [relaxation_matrix[i, i] for i in range(relaxation_matrix.rows)]
-        unique_relaxation_rates = set(rr)
-        if len(unique_relaxation_rates) <= 2:
-            # special handling for SRT and TRT
+        if keep_rr_symbolic <= 2:
+            unique_relaxation_rates = set(rr)
             subexpressions = {}
             for rt in unique_relaxation_rates:
                 rt = sp.sympify(rt)
diff --git a/lbmpy/simplificationfactory.py b/lbmpy/simplificationfactory.py
index bd0bec62..af3a0d58 100644
--- a/lbmpy/simplificationfactory.py
+++ b/lbmpy/simplificationfactory.py
@@ -2,42 +2,39 @@ import sympy as sp
 
 from lbmpy.innerloopsplit import create_lbm_split_groups
 from lbmpy.methods.cumulantbased import CumulantBasedLbMethod
+from lbmpy.methods.momentbased import MomentBasedLbMethod
+from lbmpy.methods.momentbasedsimplifications import (
+    factor_density_after_factoring_relaxation_times, factor_relaxation_rates,
+    replace_common_quadratic_and_constant_term, replace_density_and_velocity, replace_second_order_velocity_products)
 from pystencils.simp import (
-    add_subexpressions_for_divisions, apply_to_all_assignments,
-    subexpression_substitution_in_main_assignments, sympy_cse)
+    SimplificationStrategy, add_subexpressions_for_divisions, apply_to_all_assignments,
+    subexpression_substitution_in_main_assignments)
 
 
-def create_simplification_strategy(lb_method, cse_pdfs=False, cse_global=False, split_inner_loop=False):
-    from pystencils.simp import SimplificationStrategy
-    from lbmpy.methods.momentbased import MomentBasedLbMethod
-    from lbmpy.methods.momentbasedsimplifications import replace_second_order_velocity_products, \
-        factor_density_after_factoring_relaxation_times, factor_relaxation_rates, cse_in_opposing_directions, \
-        replace_common_quadratic_and_constant_term, replace_density_and_velocity
-
+def create_simplification_strategy(lb_method, split_inner_loop=False):
     s = SimplificationStrategy()
-
     expand = apply_to_all_assignments(sp.expand)
 
     if isinstance(lb_method, MomentBasedLbMethod):
-        s.add(expand)
-        s.add(replace_second_order_velocity_products)
-        s.add(expand)
-        s.add(factor_relaxation_rates)
-        s.add(replace_density_and_velocity)
-        s.add(replace_common_quadratic_and_constant_term)
-        s.add(factor_density_after_factoring_relaxation_times)
-        s.add(subexpression_substitution_in_main_assignments)
-        if split_inner_loop:
-            s.add(create_lbm_split_groups)
+        if len(set(lb_method.relaxation_rates)) <= 2:
+            s.add(expand)
+            s.add(replace_second_order_velocity_products)
+            s.add(expand)
+            s.add(factor_relaxation_rates)
+            s.add(replace_density_and_velocity)
+            s.add(replace_common_quadratic_and_constant_term)
+            s.add(factor_density_after_factoring_relaxation_times)
+            s.add(subexpression_substitution_in_main_assignments)
+            if split_inner_loop:
+                s.add(create_lbm_split_groups)
+            s.add(add_subexpressions_for_divisions)
+        else:
+            s.add(subexpression_substitution_in_main_assignments)
+            if split_inner_loop:
+                s.add(create_lbm_split_groups)
     elif isinstance(lb_method, CumulantBasedLbMethod):
         s.add(expand)
         s.add(factor_relaxation_rates)
-
-    s.add(add_subexpressions_for_divisions)
-
-    if cse_pdfs:
-        s.add(cse_in_opposing_directions)
-    if cse_global:
-        s.add(sympy_cse)
+        s.add(add_subexpressions_for_divisions)
 
     return s
diff --git a/lbmpy_tests/test_srt_trt_simplifications.py b/lbmpy_tests/test_srt_trt_simplifications.py
index 0ab30686..b54f3507 100644
--- a/lbmpy_tests/test_srt_trt_simplifications.py
+++ b/lbmpy_tests/test_srt_trt_simplifications.py
@@ -6,14 +6,15 @@ import sympy as sp
 
 from lbmpy.forcemodels import Guo
 from lbmpy.methods import create_srt, create_trt, create_trt_with_magic_number
+from lbmpy.methods.momentbasedsimplifications import cse_in_opposing_directions
 from lbmpy.simplificationfactory import create_simplification_strategy
 from lbmpy.stencils import get_stencil
 
 
 def check_method(method, limits_default, limits_cse):
-    strategy = create_simplification_strategy(method, cse_pdfs=False)
-    strategy_with_cse = create_simplification_strategy(method, cse_pdfs=True)
-
+    strategy = create_simplification_strategy(method)
+    strategy_with_cse = create_simplification_strategy(method)
+    strategy_with_cse = cse_in_opposing_directions(strategy_with_cse)
     collision_rule = method.get_collision_rule()
 
     ops_default = strategy(collision_rule).operation_count
-- 
GitLab