diff --git a/chapman_enskog/chapman_enskog.py b/chapman_enskog/chapman_enskog.py
index 8c67b26f43859f3c4588cd8861226892241ca7ee..60cc31e7e757bc0efe5a00ce5106de544e860ca5 100644
--- a/chapman_enskog/chapman_enskog.py
+++ b/chapman_enskog/chapman_enskog.py
@@ -72,6 +72,8 @@ class CeMoment(sp.Symbol):
     def __new_stage2__(cls, name, momentTuple, ceIdx=-1):
         obj = super(CeMoment, cls).__xnew__(cls, name)
         obj.momentTuple = momentTuple
+        while len(obj.momentTuple) < 3:
+            obj.momentTuple = obj.momentTuple + (0,)
         obj.ceIdx = ceIdx
         return obj
 
@@ -137,8 +139,11 @@ def substituteCollisionOperatorMoments(expr, lbMethod, collisionOpMomentName='\\
 
         momentSymbols = []
         for moment, (eqValue, rr) in lbMethod.relaxationInfoDict.items():
-            momentSymbols.append(-rr * sum(coeff * CeMoment(preCollisionMomentName, momentTuple, ceMoment.ceIdx)
-                                           for coeff, momentTuple in polynomialToExponentRepresentation(moment)))
+            if isinstance(moment, tuple):
+                momentSymbols.append(-rr * CeMoment(preCollisionMomentName, moment, ceMoment.ceIdx))
+            else:
+                momentSymbols.append(-rr * sum(coeff * CeMoment(preCollisionMomentName, momentTuple, ceMoment.ceIdx)
+                                               for coeff, momentTuple in polynomialToExponentRepresentation(moment)))
         momentSymbols = sp.Matrix(momentSymbols)
         postCollisionValue = discreteMoment(tuple(Minv * momentSymbols), momentTuple, lbMethod.stencil)
         subsDict[ceMoment] = postCollisionValue
@@ -203,112 +208,6 @@ def takeMoments(eqn, pdfToMomentName=(('f', '\Pi'), ('Cf', '\\Upsilon')), veloci
         return sum(handleProduct(t) for t in eqn.args)
 
 
-#def takeMoments(eqn, pdfName="f", pdfMomentName="\\Pi", collisionOpName="Cf", collisonOpMomentName="\\Upsilon",
-#                velocityName='c', maxPdfExpansion=5):
-#    pdfToMomentName = [[tuple(expandedSymbol(pdfName, superscript=i) for i in range(maxPdfExpansion)), pdfMomentName],
-#                       [(sp.Symbol(collisionOpName),), collisonOpMomentName]]
-#
-#    pdfSymbols = [a[0] for a in pdfToMomentName]
-#
-#    velocityTerms = tuple(expandedSymbol(velocityName, subscript=i) for i in range(3))
-#
-#    def determineFIndex(factor):
-#        FIndex = namedtuple("FIndex", ['momentName', 'ceIdx'])
-#        for symbolListId, pdfSymbolsElement in enumerate(pdfSymbols):
-#            try:
-#                return FIndex(pdfToMomentName[symbolListId][1], pdfSymbolsElement.index(factor))
-#            except ValueError:
-#                pass
-#        return None
-#
-#    def handleProduct(productTerm):
-#        fIndex = None
-#        derivativeTerm = None
-#        cIndices = []
-#        rest = 1
-#        for factor in normalizeProduct(productTerm):
-#            if isinstance(factor, Diff):
-#                assert fIndex is None
-#                fIndex = determineFIndex(factor.getArgRecursive())
-#                derivativeTerm = factor
-#            elif factor in velocityTerms:
-#                cIndices += [velocityTerms.index(factor)]
-#            else:
-#                newFIndex = determineFIndex(factor)
-#                if newFIndex is None:
-#                    rest *= factor
-#                else:
-#                    assert not(newFIndex and fIndex)
-#                    fIndex = newFIndex
-#
-#        momentTuple = [0] * len(velocityTerms)
-#        for cIdx in cIndices:
-#            momentTuple[cIdx] += 1
-#        momentTuple = tuple(momentTuple)
-#
-#        result = CeMoment(fIndex.momentName, momentTuple, fIndex.ceIdx)
-#        if derivativeTerm is not None:
-#            result = derivativeTerm.changeArgRecursive(result)
-#        result *= rest
-#        return result
-#
-#    functions = sum(pdfSymbols, ())
-#    eqn = expandUsingLinearity(eqn, functions).expand()
-#
-#    if eqn.func == sp.Mul:
-#        return handleProduct(eqn)
-#    else:
-#        assert eqn.func == sp.Add
-#        return sum(handleProduct(t) for t in eqn.args)
-
-
-def takeMomentsOld(eqn, pdfExpansionTerms, velocityTerms, momentSymbolName="\Pi"):
-    if isinstance(pdfExpansionTerms, str):
-        maxExpansion = 6
-        pdfExpansionTerms = tuple(expandedSymbol(pdfExpansionTerms, superscript=i) for i in range(maxExpansion))
-
-    if isinstance(velocityTerms, str):
-        velocityTerms = tuple(expandedSymbol(velocityTerms, subscript=i) for i in range(3))
-
-    def handleProduct(productTerm):
-        fIndex = None
-        derivativeTerm = None
-        cIndices = []
-        rest = 1
-        for factor in normalizeProduct(productTerm):
-            if isinstance(factor, Diff):
-                assert fIndex is None
-                arg = factor.getArgRecursive()
-                fIndex = pdfExpansionTerms.index(arg)
-                derivativeTerm = factor
-            elif factor in pdfExpansionTerms:
-                assert fIndex is None
-                fIndex = pdfExpansionTerms.index(factor)
-            elif factor in velocityTerms:
-                cIndices += [velocityTerms.index(factor)]
-            else:
-                rest *= factor
-
-        momentTuple = [0] * len(velocityTerms)
-        for cIdx in cIndices:
-            momentTuple[cIdx] += 1
-        momentTuple = tuple(momentTuple)
-
-        result = CeMoment(momentSymbolName, momentTuple, fIndex)
-        if derivativeTerm is not None:
-            result = derivativeTerm.changeArgRecursive(result)
-        result *= rest
-        return result
-
-    eqn = expandUsingLinearity(eqn, pdfExpansionTerms).expand()
-
-    if eqn.func == sp.Mul:
-        return handleProduct(eqn)
-    else:
-        assert eqn.func == sp.Add
-        return sum(handleProduct(t) for t in eqn.args)
-
-
 def timeDiffSelector(eq):
     return [d for d in eq.atoms(Diff) if d.label == sp.Symbol("t")]
 
@@ -341,7 +240,7 @@ def chainSolveAndSubstitute(eqSequence, unknownSelector, normalizingFunc=diffExp
         symbolsToSolveFor = unknownSelector(eq)
         if len(symbolsToSolveFor) == 0:
             continue
-        assert len(symbolsToSolveFor) <= 1, "Unknown Selector return multiple unknowns - expected <=1" + str(
+        assert len(symbolsToSolveFor) <= 1, "Unknown Selector return multiple unknowns - expected <=1\n" + str(
             symbolsToSolveFor)
         symbolToSolveFor = symbolsToSolveFor[0]
         solveRes = sp.solve(eq, symbolToSolveFor)
@@ -526,7 +425,7 @@ def computeHigherOrderMomentSubsDict(momentEquations):
 
 class ChapmanEnskogAnalysis(object):
 
-    def __init__(self, method):
+    def __init__(self, method, constants=None):
         cqc = method.conservedQuantityComputation
         self._method = method
         self._momentCache = LbMethodEqMoments(method)
@@ -543,11 +442,14 @@ class ChapmanEnskogAnalysis(object):
         momentsUntilOrder1 = [1] + list(c)
         momentsOrder2 = [c_i * c_j for c_i, c_j in productSymmetric(c, c)]
 
-        oEpsMoments1 = [expandUsingLinearity(self._takeAndInsertMoments(self.equationsGroupedByOrder[1] * moment))
+        oEpsMoments1 = [expandUsingLinearity(self._takeAndInsertMoments(self.equationsGroupedByOrder[1] * moment),
+                                             constants=constants)
                         for moment in momentsUntilOrder1]
-        oEpsMoments2 = [expandUsingLinearity(self._takeAndInsertMoments(self.equationsGroupedByOrder[1] * moment))
+        oEpsMoments2 = [expandUsingLinearity(self._takeAndInsertMoments(self.equationsGroupedByOrder[1] * moment),
+                                             constants=constants)
                         for moment in momentsOrder2]
-        oEpsSqMoments1 = [expandUsingLinearity(self._takeAndInsertMoments(self.equationsGroupedByOrder[2] * moment))
+        oEpsSqMoments1 = [expandUsingLinearity(self._takeAndInsertMoments(self.equationsGroupedByOrder[2] * moment),
+                                               constants=constants)
                           for moment in momentsUntilOrder1]
 
         self._equationsWithHigherOrderMoments = [self._ceRecombine(ord1 * self.epsilon + ord2 * self.epsilon ** 2)
@@ -587,7 +489,7 @@ class ChapmanEnskogAnalysis(object):
         return expr
 
     def getDynamicViscosity(self):
-        candidates = self._getShearViscosityCandidates()
+        candidates = self.getShearViscosityCandidates()
         if len(candidates) != 1:
             raise ValueError("Could not find expression for kinematic viscosity. "
                              "Probably method does not approximate Navier Stokes.")
@@ -599,18 +501,48 @@ class ChapmanEnskogAnalysis(object):
         else:
             return self.getDynamicViscosity()
 
-    def _getShearViscosityCandidates(self):
+    def getShearViscosityCandidates(self):
         result = set()
         dim = self._method.dim
         for i, j in productSymmetric(range(dim), range(dim), withDiagonal=False):
             result.add(-sp.cancel(self._sigmaWithoutErrorTerms[i, j] / (Diff(self.u[i], j) + Diff(self.u[j], i))))
         return result
 
-    def _getBulkViscosityCandidates(self):
+    def doesApproximateNavierStokes(self):
+        """Returns a set of equations that are required in order for the method to approximate Navier Stokes equations
+        up to second order"""
+        conditions = set([0])
+        dim = self._method.dim
+        assert dim > 1
+        # Check that shear viscosity does not depend on any u derivatives - create conditions (equations) that
+        # have to be fulfilled for this to be the case
+        viscosityReference = self._sigmaWithoutErrorTerms[0, 1].expand().coeff(Diff(self.u[0], 1))
+        for i, j in productSymmetric(range(dim), range(dim), withDiagonal=False):
+            term = self._sigmaWithoutErrorTerms[i, j]
+            equalCrossTermCondition = sp.expand(term.coeff(Diff(self.u[i], j)) - viscosityReference)
+            term = term.subs({Diff(self.u[i], j): 0,
+                              Diff(self.u[j], i): 0})
+
+            conditions.add(equalCrossTermCondition)
+            for k in range(dim):
+                symmetricTermCondition = term.coeff(Diff(self.u[k], k))
+                conditions.add(symmetricTermCondition)
+            term = term.subs({Diff(self.u[k], k): 0 for k in range(dim)})
+            conditions.add(term)
+
+        bulkCandidates = list(self.getBulkViscosityCandidates(-viscosityReference))
+        if len(bulkCandidates) > 0:
+            for i in range(1, len(bulkCandidates)):
+                conditions.add(bulkCandidates[0] - bulkCandidates[i])
+
+        return conditions
+
+    def getBulkViscosityCandidates(self, viscosity=None):
         sigma = self._sigmaWithoutErrorTerms
         assert self._sigmaWithHigherOrderMoments.is_square
         result = set()
-        viscosity = self.getDynamicViscosity()
+        if viscosity is None:
+            viscosity = self.getDynamicViscosity()
         for i in range(sigma.shape[0]):
             bulkTerm = sigma[i, i] + 2 * viscosity * Diff(self.u[i], i)
             bulkTerm = bulkTerm.expand()
@@ -625,7 +557,7 @@ class ChapmanEnskogAnalysis(object):
         return result
 
     def getBulkViscosity(self):
-        candidates = self._getBulkViscosityCandidates()
+        candidates = self.getBulkViscosityCandidates()
         if len(candidates) != 1:
             raise ValueError("Could not find expression for bulk viscosity. "
                              "Probably method does not approximate Navier Stokes.")
@@ -637,3 +569,4 @@ class ChapmanEnskogAnalysis(object):
         kinematicViscosity = self.getKinematicViscosity()
         solveRes = sp.solve(kinematicViscosity - nu, kinematicViscosity.atoms(sp.Symbol), dict=True)
         return solveRes[0]
+
diff --git a/chapman_enskog/derivative.py b/chapman_enskog/derivative.py
index 27af6f1d8db0073362e606d489928bf6fcba3f50..2570ae0984e03f3b227a0438a447db36b3bae4aa 100644
--- a/chapman_enskog/derivative.py
+++ b/chapman_enskog/derivative.py
@@ -119,6 +119,9 @@ class Diff(sp.Expr):
                 else:
                     constant *= factor
 
+        if isinstance(variable, sp.Symbol) and variable not in functions:
+            return 0
+
         if isinstance(variable, int) or variable.is_number:
             return 0
         else:
@@ -162,10 +165,12 @@ class Diff(sp.Expr):
 # ----------------------------------------------------------------------------------------------------------------------
 
 
-def expandUsingLinearity(expr, functions=None):
+def expandUsingLinearity(expr, functions=None, constants=None):
     """Expands all derivative nodes by applying Diff.splitLinear"""
     if functions is None:
         functions = expr.atoms(sp.Symbol)
+        if constants is not None:
+            functions.difference_update(constants)
 
     if isinstance(expr, Diff):
         arg = expandUsingLinearity(expr.arg, functions)
@@ -175,7 +180,11 @@ def expandUsingLinearity(expr, functions=None):
                 result += Diff(a, label=expr.label, ceIdx=expr.ceIdx).splitLinear(functions)
             return result
         else:
-            return Diff(arg, label=expr.label, ceIdx=expr.ceIdx).splitLinear(functions)
+            diff = Diff(arg, label=expr.label, ceIdx=expr.ceIdx)
+            if diff == 0:
+                return 0
+            else:
+                return diff.splitLinear(functions)
     else:
         newArgs = [expandUsingLinearity(e, functions) for e in expr.args]
         result = expr.func(*newArgs) if newArgs else expr
diff --git a/methods/creationfunctions.py b/methods/creationfunctions.py
index 87433331770f668f1481993b9a29639efb20b195..0cc66cdf2149e5a57fad3be09a6f30a805cddbb9 100644
--- a/methods/creationfunctions.py
+++ b/methods/creationfunctions.py
@@ -1,7 +1,7 @@
 from warnings import warn
 
 import sympy as sp
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from functools import reduce
 import operator
 import itertools
@@ -10,7 +10,7 @@ from lbmpy.methods.momentbased import MomentBasedLbMethod
 from lbmpy.stencils import stencilsHaveSameEntries, getStencil
 from lbmpy.moments import isEven, gramSchmidt, getDefaultMomentSetForStencil, MOMENT_SYMBOLS, \
     exponentsToPolynomialRepresentations, momentsOfOrder, momentsUpToComponentOrder, sortMomentsIntoGroupsOfSameOrder, \
-    getOrder
+    getOrder, discreteMoment
 from pystencils.sympyextensions import commonDenominator
 from lbmpy.methods.conservedquantitycomputation import DensityVelocityComputation
 from lbmpy.methods.abstractlbmethod import RelaxationInfo
@@ -73,8 +73,7 @@ def createWithContinuousMaxwellianEqMoments(stencil, momentToRelaxationRateDict,
     By using the continuous Maxwellian we automatically get a compressible model.
     """
     momToRrDict = OrderedDict(momentToRelaxationRateDict)
-    assert len(momToRrDict) == len(
-        stencil), "The number of moments has to be the same as the number of stencil entries"
+    assert len(momToRrDict) == len(stencil), "The number of moments has to be the same as the number of stencil entries"
     dim = len(stencil[0])
     densityVelocityComputation = DensityVelocityComputation(stencil, compressible, forceModel)
 
@@ -100,6 +99,34 @@ def createWithContinuousMaxwellianEqMoments(stencil, momentToRelaxationRateDict,
         return MomentBasedLbMethod(stencil, rrDict, densityVelocityComputation, forceModel)
 
 
+def createWithGivenEqMoments(stencil, momentToEqValueDict, compressible=False, forceModel=None,
+                             momentToRelaxationRateDict=defaultdict(lambda: sp.Symbol("omega"))):
+    densityVelocityComputation = DensityVelocityComputation(stencil, compressible, forceModel)
+
+    rrDict = OrderedDict()
+    for moment in momentToEqValueDict.keys():
+        rrDict[moment] = RelaxationInfo(momentToEqValueDict[moment], momentToRelaxationRateDict[moment])
+    return MomentBasedLbMethod(stencil, rrDict, densityVelocityComputation, forceModel)
+
+
+def createFromEquilibrium(stencil, equilibrium, momentToRelaxationRateDict, compressible=False, forceModel=None):
+    r"""
+    Creates a moment-based LB method using a given equilibrium distribution function
+    :param stencil: see createWithDiscreteMaxwellianEqMoments
+    :param equilibrium: list of equilibrium terms, dependent on rho and u, one for each stencil direction
+    :param momentToRelaxationRateDict: relaxation rate for each moment
+    :param compressible: see createWithDiscreteMaxwellianEqMoments
+    :param forceModel: see createWithDiscreteMaxwellianEqMoments
+    """
+    momToRrDict = OrderedDict(momentToRelaxationRateDict)
+    assert len(momToRrDict) == len(stencil), "The number of moments has to be the same as the number of stencil entries"
+    densityVelocityComputation = DensityVelocityComputation(stencil, compressible, forceModel)
+
+    rrDict = OrderedDict([(mom, RelaxationInfo(discreteMoment(equilibrium, mom, stencil), rr))
+                          for mom, rr in zip(momToRrDict.keys(), momToRrDict.values())])
+    return MomentBasedLbMethod(stencil, rrDict, densityVelocityComputation, forceModel)
+
+
 # ------------------------------------ SRT / TRT/ MRT Creators ---------------------------------------------------------
 
 
diff --git a/methods/momentbased.py b/methods/momentbased.py
index f4594eca0a24f4c34a20d4090ce3fd56d6a19496..f85f8cf2d7d77e3a19d79e965d2c881b92ab5596 100644
--- a/methods/momentbased.py
+++ b/methods/momentbased.py
@@ -91,6 +91,10 @@ class MomentBasedLbMethod(AbstractLbMethod):
         return self._getCollisionRuleWithRelaxationMatrix(D, conservedQuantityEquations=conservedQuantityEquations,
                                                           includeForceTerms=includeForceTerms)
 
+    def getEquilibriumTerms(self):
+        equilibrium = self.getEquilibrium()
+        return sp.Matrix([eq.rhs for eq in equilibrium.mainEquations])
+
     def getCollisionRule(self, conservedQuantityEquations=None):
         D = sp.diag(*self.relaxationRates)
         relaxationRateSubExpressions, D = self._generateRelaxationMatrix(D)
diff --git a/quadratic_equilibrium_construction.py b/quadratic_equilibrium_construction.py
index 8633167bfd490ede68df2fc0fbd2b2d4404fa391..160399ae7319a3d4597732805f397c3d081c98b3 100644
--- a/quadratic_equilibrium_construction.py
+++ b/quadratic_equilibrium_construction.py
@@ -27,6 +27,16 @@ def genericEquilibriumAnsatz(stencil, u=sp.symbols("u_:3")):
     return tuple(equilibrium)
 
 
+def genericEquilibriumAnsatzParameters(stencil):
+    degreesOfFreedom = set()
+    for direction in stencil:
+        speed = np.abs(direction).sum()
+        params = getParameterSymbols(speed)
+        degreesOfFreedom.update(params)
+    degreesOfFreedom.add(sp.Symbol("p"))
+    return sorted(list(degreesOfFreedom), key=lambda e: e.name)
+
+
 def matchGenericEquilibriumAnsatz(stencil, equilibrium, u=sp.symbols("u_:3")):
     """Given a quadratic equilibrium, the generic coefficients A,B,C,D are determined. 
     Returns a dict that maps these coefficients to their values. If the equilibrium does not have a
@@ -58,12 +68,13 @@ def momentConstraintEquations(stencil, equilibrium, momentToValueDict, u=sp.symb
     passed in momentToValueDict. This dict is expected to map moment tuples to values."""
     dim = len(stencil[0])
     u = u[:dim]
+    equilibrium = tuple(equilibrium)
     constraintEquations = set()
     for moment, desiredValue in momentToValueDict.items():
         genericMoment = discreteMoment(equilibrium, moment, stencil)
         equations = sp.poly(genericMoment - desiredValue, *u).coeffs()
         constraintEquations.update(equations)
-    return constraintEquations
+    return list(constraintEquations)
 
 
 def hydrodynamicMomentValues(upToOrder=3, dim=3, compressible=True):
diff --git a/stencils.py b/stencils.py
index 9ea26613bb3adeaaf3f2159fdc665f796f747796..c2ce7b75cb0c550cd5969ed557d3611e3d8adbfd 100644
--- a/stencils.py
+++ b/stencils.py
@@ -1,3 +1,4 @@
+import sympy as sp
 
 
 def getStencil(name, ordering='walberla'):
@@ -139,13 +140,14 @@ def visualizeStencil(stencil, **kwargs):
             visualizeStencil3D(stencil, **kwargs)
 
 
-def visualizeStencil2D(stencil, axes=None, data=None):
+def visualizeStencil2D(stencil, axes=None, data=None, textsize='12', **kwargs):
     """
     Creates a matplotlib 2D plot of the stencil
 
     :param stencil: sequence of directions
     :param axes: optional matplotlib axes
     :param data: data to annotate the directions with, if none given, the indices are used
+    :param textsize: size of annotation text
     """
     from matplotlib.patches import BoxStyle
     import matplotlib.pyplot as plt
@@ -166,8 +168,13 @@ def visualizeStencil2D(stencil, axes=None, data=None):
         if not(dir[0] == 0 and dir[1] == 0):
             axes.arrow(0, 0, dir[0], dir[1], head_width=0.08, head_length=head_length, color='k')
 
-        axes.text(dir[0]*text_offset, dir[1]*text_offset, str(annotation), verticalalignment='center', zorder=30,
-                  size='12', bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0))
+        if isinstance(annotation, sp.Basic):
+            annotation = "$" + sp.latex(annotation) + "$"
+        else:
+            annotation = str(annotation)
+        axes.text(dir[0]*text_offset, dir[1]*text_offset, annotation, verticalalignment='center', zorder=30,
+                  horizontalalignment='center',
+                  size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0))
 
     axes.set_axis_off()
     axes.set_aspect('equal')
@@ -175,7 +182,7 @@ def visualizeStencil2D(stencil, axes=None, data=None):
     axes.set_ylim([-text_offset * 1.1, text_offset * 1.1])
 
 
-def visualizeStencil3DBySlicing(stencil, sliceAxis=2, data=None):
+def visualizeStencil3DBySlicing(stencil, sliceAxis=2, data=None, **kwargs):
     """
     Visualizes a 3D, first-neighborhood stencil by plotting 3 slices along a given axis
 
@@ -202,12 +209,12 @@ def visualizeStencil3DBySlicing(stencil, sliceAxis=2, data=None):
         splittedData[splitIdx].append(i if data is None else data[i])
 
     for i in range(3):
-        visualizeStencil2D(splittedDirections[i], axes[i], splittedData[i])
+        visualizeStencil2D(splittedDirections[i], axes[i], splittedData[i], **kwargs)
     for i in [-1, 0, 1]:
         axes[i+1].set_title("Cut at %s=%d" % (axesNames[sliceAxis], i))
 
 
-def visualizeStencil3D(stencil, axes=None, data=None):
+def visualizeStencil3D(stencil, axes=None, data=None, textsize='8'):
     """
     Draws 3D stencil into a 3D coordinate system, parameters are similar to :func:`visualizeStencil2D`
     If data is None, no labels are drawn. To draw the labels as in the 2D case, use ``data=list(range(len(stencil)))``
@@ -260,10 +267,16 @@ def visualizeStencil3D(stencil, axes=None, data=None):
 
             a = Arrow3D([0, dir[0]], [0, dir[1]], [0, dir[2]], mutation_scale=20, lw=2, arrowstyle="-|>", color=color)
             axes.add_artist(a)
+
         if annotation:
+            if isinstance(annotation, sp.Basic):
+                annotation = "$" + sp.latex(annotation) + "$"
+            else:
+                annotation = str(annotation)
+
             axes.text(dir[0]*text_offset, dir[1]*text_offset, dir[2]*text_offset,
-                      str(annotation), verticalalignment='center', zorder=30,
-                      size='8', bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0))
+                      annotation, verticalalignment='center', zorder=30,
+                      size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0))
 
     axes.set_xlim([-text_offset*1.1, text_offset*1.1])
     axes.set_ylim([-text_offset * 1.1, text_offset * 1.1])