diff --git a/methods/momentbased.py b/methods/momentbased.py
index 1627629cc657bf92e16167b36e17a94e3cf28d69..b1bdca99d8f9bee5ed187434950310c5e1b48862 100644
--- a/methods/momentbased.py
+++ b/methods/momentbased.py
@@ -11,12 +11,10 @@ from lbmpy.moments import MOMENT_SYMBOLS, momentMatrix, exponentsToPolynomialRep
 from pystencils.equationcollection import EquationCollection
 from pystencils.sympyextensions import commonDenominator
 
-
 RelaxationInfo = namedtuple('Relaxationinfo', ['equilibriumValue', 'relaxationRate'])
 
 
 class MomentBasedLbmMethod(AbstractLbmMethod):
-
     def __init__(self, stencil, momentToRelaxationInfoDict, conservedQuantityComputation, forceModel=None):
         """
         Moment based LBM is a class to represent the single (SRT), two (TRT) and multi relaxation time (MRT) methods.
@@ -63,7 +61,7 @@ class MomentBasedLbmMethod(AbstractLbmMethod):
         undefinedEquilibriumSymbols = symbolsInEquilibriumMoments - conservedQuantities
 
         assert len(undefinedEquilibriumSymbols) == 0, "Undefined symbol(s) in equilibrium moment: %s" % \
-                                                      (undefinedEquilibriumSymbols, )
+                                                      (undefinedEquilibriumSymbols,)
 
         self._weights = None
 
@@ -94,11 +92,11 @@ class MomentBasedLbmMethod(AbstractLbmMethod):
         return table.format(content=content, nb='style="border:none"')
 
     @property
-    def zerothOrderEquilibriumMomentSymbol(self,):
+    def zerothOrderEquilibriumMomentSymbol(self, ):
         return self._conservedQuantityComputation.definedSymbols(order=0)[1]
 
     @property
-    def firstOrderEquilibriumMomentSymbols(self,):
+    def firstOrderEquilibriumMomentSymbols(self, ):
         return self._conservedQuantityComputation.definedSymbols(order=1)[1]
 
     @property
@@ -158,6 +156,8 @@ class MomentBasedLbmMethod(AbstractLbmMethod):
         M = self._momentMatrix
         m_eq = self._equilibriumMoments
 
+        relaxationRateSubExpressions, D = self._generateRelaxationMatrix(D)
+
         collisionRule = f + M.inv() * D * (m_eq - M * f)
         collisionEqs = [sp.Eq(lhs, rhs) for lhs, rhs in zip(self.postCollisionPdfSymbols, collisionRule)]
 
@@ -166,9 +166,36 @@ class MomentBasedLbmMethod(AbstractLbmMethod):
         simplificationHints.update(self._conservedQuantityComputation.definedSymbols())
         simplificationHints['relaxationRates'] = D.atoms(sp.Symbol)
         simplificationHints['stencil'] = self.stencil
-        return EquationCollection(collisionEqs, eqValueEqs.subexpressions + eqValueEqs.mainEquations,
+
+        allSubexpressions = relaxationRateSubExpressions + eqValueEqs.subexpressions + eqValueEqs.mainEquations
+        return EquationCollection(collisionEqs, allSubexpressions,
                                   simplificationHints)
 
+    @staticmethod
+    def _generateRelaxationMatrix(relaxationMatrix):
+        """
+        For SRT and TRT the equations can be easier simplified if the relaxation times are symbols, not numbers.
+        This function replaces the numbers in the relaxation matrix with symbols in this case, and returns also
+         the subexpressions, that assign the number to the newly introduced symbol
+        """
+        rr = [relaxationMatrix[i, i] for i in range(relaxationMatrix.rows)]
+        uniqueRelaxationRates = set(rr)
+        if len(uniqueRelaxationRates) <= 2:
+            # special handling for SRT and TRT
+            subexpressions = {}
+            for rt in uniqueRelaxationRates:
+                rt = sp.sympify(rt)
+                if not isinstance(rt, sp.Symbol):
+                    rtSymbol = sp.Symbol("rt_%d" % (len(subexpressions),))
+                    subexpressions[rt] = rtSymbol
+
+            newRR = [subexpressions[sp.sympify(e)] if sp.sympify(e) in subexpressions else e
+                     for e in rr]
+            substitutions = [sp.Eq(e[1], e[0]) for e in subexpressions.items()]
+            return substitutions, sp.diag(*newRR)
+        else:
+            return [], relaxationMatrix
+
 
 # ------------------------------------ Helper Functions ----------------------------------------------------------------
 
diff --git a/methods/momentbasedsimplifications.py b/methods/momentbasedsimplifications.py
index 702d0647394aa3750003b0f8c10b7f1a1df21d80..d0fa820070aad09a2db4df147d85837ee4af5629 100644
--- a/methods/momentbasedsimplifications.py
+++ b/methods/momentbasedsimplifications.py
@@ -93,10 +93,13 @@ def replaceDensityAndVelocity(lbmCollisionEqs):
 
 def replaceCommonQuadraticAndConstantTerm(lbmCollisionEqs):
     """
+    A common quadratic term (f_eq_common) is extracted from the collision equation for center
+    and substituted in all equations
+
     Required simplification hints:
         - density: density symbol
         - velocity: sequence of velocity symbols
-        - relaxationRates:
+        - relaxationRates: set of symbolic relaxation rates
         - stencil:
     """
     sh = lbmCollisionEqs.simplificationHints
@@ -123,6 +126,10 @@ def replaceCommonQuadraticAndConstantTerm(lbmCollisionEqs):
 def cseInOpposingDirections(lbmCollisionEqs):
     """
     Looks for common subexpressions in terms for opposing directions (e.g. north & south, top & bottom )
+
+    Required simplification hints:
+        - relaxationRates: set of symbolic relaxation rates
+        - stencil:
     """
     sh = lbmCollisionEqs.simplificationHints
     assert 'stencil' in sh, "Needs simplification hint 'stencil': Sequence of discrete velocities"
@@ -214,6 +221,6 @@ def __getCommonQuadraticAndConstantTerms(lbmCollisionEqs):
 
     for u in sh['velocity']:
         weight = weight.subs(u, 0)
-    weight = weight / sh['rho']
+    weight = weight / sh['density']
     return t / weight
 
diff --git a/simplificationfactory.py b/simplificationfactory.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a3c01d0f0e9d97ef969784ea2a2c1f6ceacafa3
--- /dev/null
+++ b/simplificationfactory.py
@@ -0,0 +1,48 @@
+from functools import partial
+import sympy as sp
+from pystencils.equationcollection.simplifications import applyOnAllEquations, \
+    subexpressionSubstitutionInMainEquations, sympyCSE, addSubexpressionsForDivisions
+
+
+def createSimplificationStrategy(lbmMethod, doCseInOpposingDirections=False, doOverallCse=False):
+    from pystencils.equationcollection import SimplificationStrategy
+    from lbmpy.methods import MomentBasedLbmMethod
+    from lbmpy.methods.momentbasedsimplifications import replaceSecondOrderVelocityProducts, \
+        factorDensityAfterFactoringRelaxationTimes, factorRelaxationRates, cseInOpposingDirections, \
+        replaceCommonQuadraticAndConstantTerm, replaceDensityAndVelocity
+
+    s = SimplificationStrategy()
+
+    if isinstance(lbmMethod, MomentBasedLbmMethod):
+        expand = partial(applyOnAllEquations, operation=sp.expand)
+        expand.__name__ = "expand"
+
+        s.add(expand)
+        s.add(replaceSecondOrderVelocityProducts)
+        s.add(expand)
+        s.add(factorRelaxationRates)
+        s.add(replaceDensityAndVelocity)
+        s.add(replaceCommonQuadraticAndConstantTerm)
+        s.add(factorDensityAfterFactoringRelaxationTimes)
+        s.add(subexpressionSubstitutionInMainEquations)
+
+    if doCseInOpposingDirections:
+        s.add(cseInOpposingDirections)
+    if doOverallCse:
+        s.add(sympyCSE)
+
+    s.add(addSubexpressionsForDivisions)
+
+    return s
+
+
+if __name__ == '__main__':
+    from lbmpy.stencils import getStencil
+    from lbmpy.methods.momentbased import createOrthogonalMRT
+
+    stencil = getStencil("D2Q9")
+    m = createOrthogonalMRT(stencil, compressible=True)
+    cr = m.getCollisionRule()
+
+    simp = createSimplificationStrategy(m)
+    simp(cr)