diff --git a/cumulantlatticemodel.py b/cumulantlatticemodel.py
index a03fe6a2a10bbca4e9d89adcc4d57c6da80bee2e..8412961d1992838ceb498b4e628276225d12defd 100644
--- a/cumulantlatticemodel.py
+++ b/cumulantlatticemodel.py
@@ -9,7 +9,52 @@ from lbmpy.equilibria import getWeights
 from lbmpy.densityVelocityExpressions import getDensityVelocityExpressions
 
 
-class SimpleBoltzmannRelaxation:
+def isHydrodynamic(idx):
+    return sum(idx) == 2
+
+
+def relaxationRateFromMagicNumber(hydrodynamicOmega, magicNumber):
+    half = sp.Rational(1, 2)
+    return 1 / (magicNumber / (1 / hydrodynamicOmega - half) + half)
+
+
+def relaxationRateFromFactor(hydrodynamicOmega, factor):
+    half = sp.Rational(1, 2)
+    return 1 / ((1 / hydrodynamicOmega - half) / factor + half)
+
+
+class TRTStyleRelaxation:
+    def __init__(self, omega, factorEven=1.0, factorOdd=sp.Rational(3, 16)):
+        self._omega = omega
+        self._factorEven = factorEven
+        self._factorOdd = factorOdd
+        self.addPostCollisionsAsSubexpressions = False
+
+    def __call__(self, preCollisionSymbols, indices):
+        pre = {a: b for a, b in zip(indices, preCollisionSymbols)}
+        post = {}
+        dim = len(indices[0])
+
+        omegaEven = relaxationRateFromFactor(self._omega, self._factorEven)
+        omegaOdd = relaxationRateFromMagicNumber(self._omega, self._factorOdd)
+
+        print(omegaEven, omegaOdd)
+        maxwellBoltzmann = maxwellBoltzmannEquilibrium(dim, c_s_sq=sp.Rational(1, 3))
+        for idx, value in pre.items():
+            isEven = sum(idx) % 2 == 0
+            if isHydrodynamic(idx):
+                relaxRate = self._omega
+            elif isEven:
+                relaxRate = omegaEven
+            else:
+                relaxRate = omegaOdd
+
+            post[idx] = pre[idx] + relaxRate * (continuousCumulant(maxwellBoltzmann, idx) - pre[idx])
+
+        return [post[idx] for idx in indices]
+
+
+class AllSameRelaxation:
     def __init__(self, omega):
         self._omega = omega
         self.addPostCollisionsAsSubexpressions = False
@@ -19,21 +64,28 @@ class SimpleBoltzmannRelaxation:
         post = {}
         dim = len(indices[0])
 
-        # conserved quantities
+        maxwellBoltzmann = maxwellBoltzmannEquilibrium(dim, c_s_sq=sp.Rational(1, 3))
         for idx, value in pre.items():
-            if sum(idx) == 0 or sum(idx) == 1:
-                post[idx] = pre[idx]
+            post[idx] = pre[idx] + self._omega * (continuousCumulant(maxwellBoltzmann, idx) - pre[idx])
+
+        return [post[idx] for idx in indices]
 
-        # hydrodynamic relaxation
-        for idx in indices:
-            idxCounter = Counter(idx)
-            if len(idxCounter.keys() - set([0, 1])) == 0 and idxCounter[1] == 2:
-                post[idx] = (1 - self._omega) * pre[idx]
 
-        # set remaining values to their equilibrium value (i.e. relaxationRate=1)
+class SimpleBoltzmannRelaxation:
+    def __init__(self, omega):
+        self._omega = omega
+        self.addPostCollisionsAsSubexpressions = False
+
+    def __call__(self, preCollisionSymbols, indices):
+        pre = {a: b for a, b in zip(indices, preCollisionSymbols)}
+        post = {}
+        dim = len(indices[0])
+
         maxwellBoltzmann = maxwellBoltzmannEquilibrium(dim, c_s_sq=sp.Rational(1, 3))
         for idx, value in pre.items():
-            if idx not in post:
+            if isHydrodynamic(idx):
+                post[idx] = (1 - self._omega) * pre[idx]
+            else:
                 post[idx] = continuousCumulant(maxwellBoltzmann, idx)
 
         return [post[idx] for idx in indices]
diff --git a/plot2d.py b/plot2d.py
index d7dd77ded5b4d7af5b4c01fa758ace65ef1d2350..ca60b4e3560ec077ddaa102f34f9e93d85842b11 100644
--- a/plot2d.py
+++ b/plot2d.py
@@ -15,11 +15,31 @@ def vectorField(field, step=2, **kwargs):
 
 def vectorFieldMagnitude(field, **kwargs):
     field = norm(field, axis=2, ord=2)
-    scalarField(field, **kwargs)
+    return scalarField(field, **kwargs)
 
 
 def scalarField(field, **kwargs):
     field = removeGhostLayers(field)
     field = np.swapaxes(field, 0, 1)
-    imshow(field, origin='lower', **kwargs)
+    return imshow(field, origin='lower', **kwargs)
 
+
+def vectorFieldMagnitudeAnimation(runFunction, plotSetupFunction=lambda: None,
+                                  plotUpdateFunction=lambda: None, interval=30, frames=180, **kwargs):
+    import matplotlib.animation as animation
+
+    fig = gcf()
+    im = None
+    field = runFunction()
+    im = vectorFieldMagnitude(field, **kwargs)
+    plotSetupFunction()
+
+    def updatefig(*args):
+        field = runFunction()
+        field = norm(field, axis=2, ord=2)
+        field = np.swapaxes(field, 0, 1)
+        im.set_array(field)
+        plotUpdateFunction()
+        return im,
+
+    return animation.FuncAnimation(fig, updatefig, interval=interval, frames=frames)