From ca1041b66c5532fc120b9f1b3ef7e28c9ac0de58 Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Fri, 6 Jul 2018 10:42:07 +0200
Subject: [PATCH] Bugfix: entropic methods now work with CSE

- CSE has to be done after dynamic relaxation rate adaption
---
 creationfunctions.py                  | 16 +++++++++++-----
 methods/entropic_eq_srt.py            |  5 +++--
 methods/momentbasedsimplifications.py |  5 ++++-
 3 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/creationfunctions.py b/creationfunctions.py
index f0d545fe..5f02e262 100644
--- a/creationfunctions.py
+++ b/creationfunctions.py
@@ -266,11 +266,8 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs):
         lb_method = create_lb_method(**params)
 
     split_inner_loop = 'split' in opt_params and opt_params['split']
-
-    dir_cse = 'cse_pdfs'
-    cse_pdfs = False if dir_cse not in opt_params else opt_params[dir_cse]
-    cse_global = False if 'cse_global' not in opt_params else opt_params['cse_global']
-    simplification = create_simplification_strategy(lb_method, cse_pdfs, cse_global, split_inner_loop)
+    simplification = create_simplification_strategy(lb_method, cse_pdfs=False, cse_global=False,
+                                                    split_inner_loop=split_inner_loop)
     cqc = lb_method.conserved_quantity_computation
 
     if params['velocity_input'] is not None:
@@ -303,6 +300,15 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs):
         if 'split_groups' in collision_rule.simplification_hints:
             collision_rule.simplification_hints['split_groups'][0].append(sp.Symbol("smagorinsky_omega"))
 
+    cse_pdfs = False if 'cse_pdfs' not in opt_params else opt_params['cse_pdfs']
+    cse_global = False if 'cse_global' not in opt_params else opt_params['cse_global']
+    if cse_pdfs:
+        from lbmpy.methods.momentbasedsimplifications import cse_in_opposing_directions
+        collision_rule = cse_in_opposing_directions(collision_rule)
+    if cse_global:
+        from pystencils.simp import sympy_cse
+        collision_rule = sympy_cse(collision_rule)
+
     return collision_rule
 
 
diff --git a/methods/entropic_eq_srt.py b/methods/entropic_eq_srt.py
index eb2116b5..448be382 100644
--- a/methods/entropic_eq_srt.py
+++ b/methods/entropic_eq_srt.py
@@ -65,8 +65,9 @@ class EntropicEquilibriumSRT(AbstractLbMethod):
             all_subexpressions += force_subexpressions
             collision_eqs = [Assignment(eq.lhs, eq.rhs + force_term_symbol)
                              for eq, force_term_symbol in zip(collision_eqs, force_term_symbols)]
-
-        return LbmCollisionRule(self, collision_eqs, all_subexpressions)
+        cr = LbmCollisionRule(self, collision_eqs, all_subexpressions)
+        cr.simplification_hints['relaxation_rates'] = []
+        return cr
 
     def get_collision_rule(self):
         return self._get_collision_rule_with_relaxation_rate(self._relaxationRate)
diff --git a/methods/momentbasedsimplifications.py b/methods/momentbasedsimplifications.py
index 10bbd488..1888f353 100644
--- a/methods/momentbasedsimplifications.py
+++ b/methods/momentbasedsimplifications.py
@@ -147,9 +147,12 @@ def cse_in_opposing_directions(cr: LbmCollisionRule):
     """
     sh = cr.simplification_hints
     assert 'relaxation_rates' in sh, "Needs simplification hint 'relaxation_rates': Sequence of relaxation rates"
-
     update_rules = cr.main_assignments
     stencil = cr.method.stencil
+
+    if not sh['relaxation_rates']:
+        return cr
+
     relaxation_rates = sp.Matrix(sh['relaxation_rates']).atoms(sp.Symbol)
 
     replacement_symbol_generator = cr.subexpression_symbol_generator
-- 
GitLab