diff --git a/lbmpy/creationfunctions.py b/lbmpy/creationfunctions.py
index 477bea98390964fc4b0738babeac356d12fd8b7d..57b953312df9a00a16e643ce77ee760bdb6b0dcd 100644
--- a/lbmpy/creationfunctions.py
+++ b/lbmpy/creationfunctions.py
@@ -198,6 +198,7 @@ from pystencils import Assignment, AssignmentCollection, create_kernel
 from pystencils.cache import disk_cache_no_fallback
 from pystencils.data_types import collate_types
 from pystencils.field import Field, get_layout_of_array
+from pystencils.simp import sympy_cse
 from pystencils.stencil import have_same_entries
 
 
@@ -300,8 +301,6 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs):
         lb_method = create_lb_method(**params)
 
     split_inner_loop = 'split' in opt_params and opt_params['split']
-    simplification = create_simplification_strategy(lb_method, cse_pdfs=False, cse_global=False,
-                                                    split_inner_loop=split_inner_loop)
     cqc = lb_method.conserved_quantity_computation
 
     rho_in = params['density_input']
@@ -312,17 +311,23 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs):
     if rho_in is not None and isinstance(rho_in, Field):
         rho_in = rho_in.center
 
+    keep_rrs_symbolic = opt_params['keep_rrs_symbolic']
     if u_in is not None:
         density_rhs = sum(lb_method.pre_collision_pdf_symbols) if rho_in is None else rho_in
         eqs = [Assignment(cqc.zeroth_order_moment_symbol, density_rhs)]
         eqs += [Assignment(u_sym, u_in[i]) for i, u_sym in enumerate(cqc.first_order_moment_symbols)]
         eqs = AssignmentCollection(eqs, [])
-        collision_rule = lb_method.get_collision_rule(conserved_quantity_equations=eqs)
+        collision_rule = lb_method.get_collision_rule(conserved_quantity_equations=eqs,
+                                                      keep_rrs_symbolic=keep_rrs_symbolic)
     elif u_in is None and rho_in is not None:
         raise ValueError("When setting 'density_input' parameter, 'velocity_input' has to be specified as well.")
     else:
-        collision_rule = lb_method.get_collision_rule()
+        collision_rule = lb_method.get_collision_rule(keep_rrs_symbolic=keep_rrs_symbolic)
 
+    if opt_params['simplification'] == 'auto':
+        simplification = create_simplification_strategy(lb_method, split_inner_loop=split_inner_loop)
+    else:
+        simplification = opt_params['simplification']
     collision_rule = simplification(collision_rule)
 
     if params['fluctuating']:
@@ -353,7 +358,6 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs):
         from lbmpy.methods.momentbasedsimplifications import cse_in_opposing_directions
         collision_rule = cse_in_opposing_directions(collision_rule)
     if cse_global:
-        from pystencils.simp import sympy_cse
         collision_rule = sympy_cse(collision_rule)
 
     if params['output'] and params['kernel_type'] == 'stream_pull_collide':
@@ -550,6 +554,8 @@ def update_with_default_parameters(params, opt_params=None, fail_on_unknown_para
     default_optimization_description = {
         'cse_pdfs': False,
         'cse_global': False,
+        'simplification': 'auto',
+        'keep_rrs_symbolic': True,
         'split': False,
 
         'field_size': None,
diff --git a/lbmpy/methods/momentbased.py b/lbmpy/methods/momentbased.py
index 50f0ef87e3c680f7cd9ad1bf503dc123028de50e..dfde11c52406b1ac98712e4fb6aca72e447e021c 100644
--- a/lbmpy/methods/momentbased.py
+++ b/lbmpy/methods/momentbased.py
@@ -88,9 +88,9 @@ class MomentBasedLbMethod(AbstractLbMethod):
         equilibrium = self.get_equilibrium()
         return sp.Matrix([eq.rhs for eq in equilibrium.main_assignments])
 
-    def get_collision_rule(self, conserved_quantity_equations=None):
+    def get_collision_rule(self, conserved_quantity_equations=None, keep_rrs_symbolic=True):
         d = sp.diag(*self.relaxation_rates)
-        relaxation_rate_sub_expressions, d = self._generate_relaxation_matrix(d)
+        relaxation_rate_sub_expressions, d = self._generate_relaxation_matrix(d, keep_rrs_symbolic)
         ac = self._collision_rule_with_relaxation_matrix(d, relaxation_rate_sub_expressions,
                                                          True, conserved_quantity_equations)
         return ac
@@ -216,16 +216,15 @@ class MomentBasedLbMethod(AbstractLbMethod):
                                 simplification_hints)
 
     @staticmethod
-    def _generate_relaxation_matrix(relaxation_matrix):
+    def _generate_relaxation_matrix(relaxation_matrix, keep_rr_symbolic):
         """
         For SRT and TRT the equations can be easier simplified if the relaxation times are symbols, not numbers.
         This function replaces the numbers in the relaxation matrix with symbols in this case, and returns also
          the subexpressions, that assign the number to the newly introduced symbol
         """
         rr = [relaxation_matrix[i, i] for i in range(relaxation_matrix.rows)]
-        unique_relaxation_rates = set(rr)
-        if len(unique_relaxation_rates) <= 2:
-            # special handling for SRT and TRT
+        if keep_rr_symbolic <= 2:
+            unique_relaxation_rates = set(rr)
             subexpressions = {}
             for rt in unique_relaxation_rates:
                 rt = sp.sympify(rt)
diff --git a/lbmpy/simplificationfactory.py b/lbmpy/simplificationfactory.py
index bd0bec628fe4530fd21ca540e3614b83b7f8a397..af3a0d5854a19f82bc2d6967a5e082c7149bea96 100644
--- a/lbmpy/simplificationfactory.py
+++ b/lbmpy/simplificationfactory.py
@@ -2,42 +2,39 @@ import sympy as sp
 
 from lbmpy.innerloopsplit import create_lbm_split_groups
 from lbmpy.methods.cumulantbased import CumulantBasedLbMethod
+from lbmpy.methods.momentbased import MomentBasedLbMethod
+from lbmpy.methods.momentbasedsimplifications import (
+    factor_density_after_factoring_relaxation_times, factor_relaxation_rates,
+    replace_common_quadratic_and_constant_term, replace_density_and_velocity, replace_second_order_velocity_products)
 from pystencils.simp import (
-    add_subexpressions_for_divisions, apply_to_all_assignments,
-    subexpression_substitution_in_main_assignments, sympy_cse)
+    SimplificationStrategy, add_subexpressions_for_divisions, apply_to_all_assignments,
+    subexpression_substitution_in_main_assignments)
 
 
-def create_simplification_strategy(lb_method, cse_pdfs=False, cse_global=False, split_inner_loop=False):
-    from pystencils.simp import SimplificationStrategy
-    from lbmpy.methods.momentbased import MomentBasedLbMethod
-    from lbmpy.methods.momentbasedsimplifications import replace_second_order_velocity_products, \
-        factor_density_after_factoring_relaxation_times, factor_relaxation_rates, cse_in_opposing_directions, \
-        replace_common_quadratic_and_constant_term, replace_density_and_velocity
-
+def create_simplification_strategy(lb_method, split_inner_loop=False):
     s = SimplificationStrategy()
-
     expand = apply_to_all_assignments(sp.expand)
 
     if isinstance(lb_method, MomentBasedLbMethod):
-        s.add(expand)
-        s.add(replace_second_order_velocity_products)
-        s.add(expand)
-        s.add(factor_relaxation_rates)
-        s.add(replace_density_and_velocity)
-        s.add(replace_common_quadratic_and_constant_term)
-        s.add(factor_density_after_factoring_relaxation_times)
-        s.add(subexpression_substitution_in_main_assignments)
-        if split_inner_loop:
-            s.add(create_lbm_split_groups)
+        if len(set(lb_method.relaxation_rates)) <= 2:
+            s.add(expand)
+            s.add(replace_second_order_velocity_products)
+            s.add(expand)
+            s.add(factor_relaxation_rates)
+            s.add(replace_density_and_velocity)
+            s.add(replace_common_quadratic_and_constant_term)
+            s.add(factor_density_after_factoring_relaxation_times)
+            s.add(subexpression_substitution_in_main_assignments)
+            if split_inner_loop:
+                s.add(create_lbm_split_groups)
+            s.add(add_subexpressions_for_divisions)
+        else:
+            s.add(subexpression_substitution_in_main_assignments)
+            if split_inner_loop:
+                s.add(create_lbm_split_groups)
     elif isinstance(lb_method, CumulantBasedLbMethod):
         s.add(expand)
         s.add(factor_relaxation_rates)
-
-    s.add(add_subexpressions_for_divisions)
-
-    if cse_pdfs:
-        s.add(cse_in_opposing_directions)
-    if cse_global:
-        s.add(sympy_cse)
+        s.add(add_subexpressions_for_divisions)
 
     return s
diff --git a/lbmpy_tests/test_srt_trt_simplifications.py b/lbmpy_tests/test_srt_trt_simplifications.py
index 0ab3068619ef09f87b189ab94bb9df3e5dc2053e..b54f350735ccd6f6afb808b16cedc075f3120bc1 100644
--- a/lbmpy_tests/test_srt_trt_simplifications.py
+++ b/lbmpy_tests/test_srt_trt_simplifications.py
@@ -6,14 +6,15 @@ import sympy as sp
 
 from lbmpy.forcemodels import Guo
 from lbmpy.methods import create_srt, create_trt, create_trt_with_magic_number
+from lbmpy.methods.momentbasedsimplifications import cse_in_opposing_directions
 from lbmpy.simplificationfactory import create_simplification_strategy
 from lbmpy.stencils import get_stencil
 
 
 def check_method(method, limits_default, limits_cse):
-    strategy = create_simplification_strategy(method, cse_pdfs=False)
-    strategy_with_cse = create_simplification_strategy(method, cse_pdfs=True)
-
+    strategy = create_simplification_strategy(method)
+    strategy_with_cse = create_simplification_strategy(method)
+    strategy_with_cse = cse_in_opposing_directions(strategy_with_cse)
     collision_rule = method.get_collision_rule()
 
     ops_default = strategy(collision_rule).operation_count