diff --git a/creationfunctions.py b/creationfunctions.py index f0d545fe469a7919d781459ea7ab2bf3d7ace281..5f02e262ee23ec8a5ab55d7450caf8f1563e96b4 100644 --- a/creationfunctions.py +++ b/creationfunctions.py @@ -266,11 +266,8 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs): lb_method = create_lb_method(**params) split_inner_loop = 'split' in opt_params and opt_params['split'] - - dir_cse = 'cse_pdfs' - cse_pdfs = False if dir_cse not in opt_params else opt_params[dir_cse] - cse_global = False if 'cse_global' not in opt_params else opt_params['cse_global'] - simplification = create_simplification_strategy(lb_method, cse_pdfs, cse_global, split_inner_loop) + simplification = create_simplification_strategy(lb_method, cse_pdfs=False, cse_global=False, + split_inner_loop=split_inner_loop) cqc = lb_method.conserved_quantity_computation if params['velocity_input'] is not None: @@ -303,6 +300,15 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs): if 'split_groups' in collision_rule.simplification_hints: collision_rule.simplification_hints['split_groups'][0].append(sp.Symbol("smagorinsky_omega")) + cse_pdfs = False if 'cse_pdfs' not in opt_params else opt_params['cse_pdfs'] + cse_global = False if 'cse_global' not in opt_params else opt_params['cse_global'] + if cse_pdfs: + from lbmpy.methods.momentbasedsimplifications import cse_in_opposing_directions + collision_rule = cse_in_opposing_directions(collision_rule) + if cse_global: + from pystencils.simp import sympy_cse + collision_rule = sympy_cse(collision_rule) + return collision_rule diff --git a/methods/entropic_eq_srt.py b/methods/entropic_eq_srt.py index eb2116b57371f48c91d4e3a051c17819c0a175a2..448be382950b7d4c1802835e90898e1013042147 100644 --- a/methods/entropic_eq_srt.py +++ b/methods/entropic_eq_srt.py @@ -65,8 +65,9 @@ class EntropicEquilibriumSRT(AbstractLbMethod): all_subexpressions += force_subexpressions collision_eqs = [Assignment(eq.lhs, eq.rhs + force_term_symbol) for eq, force_term_symbol in zip(collision_eqs, force_term_symbols)] - - return LbmCollisionRule(self, collision_eqs, all_subexpressions) + cr = LbmCollisionRule(self, collision_eqs, all_subexpressions) + cr.simplification_hints['relaxation_rates'] = [] + return cr def get_collision_rule(self): return self._get_collision_rule_with_relaxation_rate(self._relaxationRate) diff --git a/methods/momentbasedsimplifications.py b/methods/momentbasedsimplifications.py index 10bbd488c4a3c56e5a573679017bda3c7a584104..1888f3536f42189851824bc6ae7d945a5bc1dd45 100644 --- a/methods/momentbasedsimplifications.py +++ b/methods/momentbasedsimplifications.py @@ -147,9 +147,12 @@ def cse_in_opposing_directions(cr: LbmCollisionRule): """ sh = cr.simplification_hints assert 'relaxation_rates' in sh, "Needs simplification hint 'relaxation_rates': Sequence of relaxation rates" - update_rules = cr.main_assignments stencil = cr.method.stencil + + if not sh['relaxation_rates']: + return cr + relaxation_rates = sp.Matrix(sh['relaxation_rates']).atoms(sp.Symbol) replacement_symbol_generator = cr.subexpression_symbol_generator