diff --git a/chapman_enskog/derivative.py b/chapman_enskog/derivative.py
index b0b39d746d8d5259d8fc3a84b2965c67aba5da16..6a07e3a3a2c4d10fbe3b55220c2a9f2818644a6f 100644
--- a/chapman_enskog/derivative.py
+++ b/chapman_enskog/derivative.py
@@ -1,302 +1,4 @@
-import sympy as sp
-from collections import namedtuple, defaultdict
-from pystencils.sympyextensions import normalizeProduct, prod
-
-
-def defaultDiffSortKey(d):
-    return str(d.ceIdx), str(d.label)
-
-
-class DiffOperator(sp.Expr):
-    """
-    Un-applied differential, i.e. differential operator
-    Its args are:
-        - label: the differential is w.r.t to this label / variable. 
-                 This label is mainly for display purposes (its the subscript) and to distinguish DiffOperators
-                 If the label is '-1' no subscript is displayed
-        - ceIdx: expansion order index in the Chapman Enskog expansion. It is displayed as superscript.
-                 and not displayed if set to '-1'
-    The DiffOperator behaves much like a variable with special name. Its main use is to be applied later, using the
-    DiffOperator.apply(expr, arg) which transforms 'DiffOperator's to applied 'Diff's         
-    """
-    is_commutative = True
-    is_number = False
-    is_Rational = False
-
-    def __new__(cls, label=-1, ceIdx=-1, **kwargs):
-        return sp.Expr.__new__(cls, sp.sympify(label), sp.sympify(ceIdx), **kwargs)
-
-    @property
-    def label(self):
-        return self.args[0]
-
-    @property
-    def ceIdx(self):
-        return self.args[1]
-
-    def _latex(self, printer, *args):
-        result = "{\partial"
-        if self.ceIdx >= 0:
-            result += "^{(%s)}" % (self.ceIdx,)
-        if self.label != -1:
-            result += "_{%s}" % (self.label,)
-        result += "}"
-        return result
-
-    @staticmethod
-    def apply(expr, argument):
-        """
-        Returns a new expression where each 'DiffOperator' is replaced by a 'Diff' node.
-        Multiplications of 'DiffOperator's are interpreted as nested application of differentiation:
-        i.e. DiffOperator('x')*DiffOperator('x') is a second derivative replaced by Diff(Diff(arg, x), t)
-        """
-        def handleMul(mul):
-            args = normalizeProduct(mul)
-            diffs = [a for a in args if isinstance(a, DiffOperator)]
-            if len(diffs) == 0:
-                return mul
-            rest = [a for a in args if not isinstance(a, DiffOperator)]
-            diffs.sort(key=defaultDiffSortKey)
-            result = argument
-            for d in reversed(diffs):
-                result = Diff(result, label=d.label, ceIdx=d.ceIdx)
-            return prod(rest) * result
-
-        expr = expr.expand()
-        if expr.func == sp.Mul or expr.func == sp.Pow:
-            return handleMul(expr)
-        elif expr.func == sp.Add:
-            return expr.func(*[handleMul(a) for a in expr.args])
-        else:
-            return expr
-
-
-class Diff(sp.Expr):
-    """
-    Sympy Node representing a derivative. The difference to sympy's built in differential is:
-        - shortened latex representation 
-        - all simplifications have to be done manually
-        - each Diff has a Chapman Enskog expansion order index: 'ceIdx'
-    """
-    is_number = False
-    is_Rational = False
-
-    def __new__(cls, argument, label=-1, ceIdx=-1, **kwargs):
-        if argument == 0:
-            return sp.Rational(0, 1)
-        return sp.Expr.__new__(cls, argument.expand(), sp.sympify(label), sp.sympify(ceIdx), **kwargs)
-
-    @property
-    def is_commutative(self):
-        anyNonCommutative = any(not s.is_commutative for s in self.atoms(sp.Symbol))
-        if anyNonCommutative:
-            return False
-        else:
-            return True
-
-    def getArgRecursive(self):
-        """Returns the argument the derivative acts on, for nested derivatives the inner argument is returned"""
-        if not isinstance(self.arg, Diff):
-            return self.arg
-        else:
-            return self.arg.getArgRecursive()
-
-    def changeArgRecursive(self, newArg):
-        """Returns a Diff node with the given 'newArg' instead of the current argument. For nested derivatives 
-        a new nested derivative is returned where the inner Diff has the 'newArg'"""
-        if not isinstance(self.arg, Diff):
-            return Diff(newArg, self.label, self.ceIdx)
-        else:
-            return Diff(self.arg.changeArgRecursive(newArg), self.label, self.ceIdx)
-
-    def splitLinear(self, functions):
-        """
-        Applies linearity property of Diff: i.e.  'Diff(c*a+b)' is transformed to 'c * Diff(a) + Diff(b)'
-        The parameter functions is a list of all symbols that are considered functions, not constants.
-        For the example above: functions=[a, b]
-        """
-        constant, variable = 1, 1
-
-        if self.arg.func != sp.Mul:
-            constant, variable = 1, self.arg
-        else:
-            for factor in normalizeProduct(self.arg):
-                if factor in functions or isinstance(factor, Diff):
-                    variable *= factor
-                else:
-                    constant *= factor
-
-        if isinstance(variable, sp.Symbol) and variable not in functions:
-            return 0
-
-        if isinstance(variable, int) or variable.is_number:
-            return 0
-        else:
-            return constant * Diff(variable, label=self.label, ceIdx=self.ceIdx)
-
-    @property
-    def arg(self):
-        """Expression the derivative acts on"""
-        return self.args[0]
-
-    @property
-    def label(self):
-        """Subscript, usually the variable the Diff is w.r.t. """
-        return self.args[1]
-
-    @property
-    def ceIdx(self):
-        """Superscript, used as the Chapman Enskog order index"""
-        return self.args[2]
-
-    def _latex(self, printer, *args):
-        result = "{\partial"
-        if self.ceIdx >= 0:
-            result += "^{(%s)}" % (self.ceIdx,)
-        if self.label != -1:
-            result += "_{%s}" % (printer.doprint(self.label),)
-
-        contents = printer.doprint(self.arg)
-        if isinstance(self.arg, int) or isinstance(self.arg, sp.Symbol) or self.arg.is_number or self.arg.func == Diff:
-            result += " " + contents
-        else:
-            result += " (" + contents + ") "
-
-        result += "}"
-        return result
-
-    def __str__(self):
-        #return "Diff(%s, %s, %s)" % (self.arg, self.label, self.ceIdx)
-        return "D(%s)" % (self.arg)
-
-
-# ----------------------------------------------------------------------------------------------------------------------
-
-def derivativeTerms(expr):
-    """
-    Returns set of all derivatives in an expression
-    this is different from `expr.atoms(Diff)` when nested derivatives are in the expression, 
-    since this function only returns the outer derivatives
-    """
-    result = set()
-
-    def visit(e):
-        if isinstance(e, Diff):
-            result.add(e)
-        else:
-            for a in e.args:
-                visit(a)
-    visit(expr)
-    return result
-
-
-def collectDerivatives(expr):
-    """Rewrites expression into a sum of distinct derivatives with prefactors"""
-    return expr.collect(derivativeTerms(expr))
-
-
-def createNestedDiff(*args, arg=None):
-    """Shortcut to create nested derivatives"""
-    assert arg is not None
-    args = sorted(args, reverse=True)
-    res = arg
-    for i in args:
-        res = Diff(res, i)
-    return res
-
-
-def expandUsingLinearity(expr, functions=None, constants=None):
-    """
-    Expands all derivative nodes by applying Diff.splitLinear
-    :param expr: expression containing derivatives
-    :param functions: sequence of symbols that are considered functions and can not be pulled before the derivative.
-                      if None, all symbols are viewed as functions
-    :param constants: sequence of symbols which are considered constants and can be pulled before the derivative
-    """
-    if functions is None:
-        functions = expr.atoms(sp.Symbol)
-        if constants is not None:
-            functions.difference_update(constants)
-
-    if isinstance(expr, Diff):
-        arg = expandUsingLinearity(expr.arg, functions)
-        if hasattr(arg, 'func') and arg.func == sp.Add:
-            result = 0
-            for a in arg.args:
-                result += Diff(a, label=expr.label, ceIdx=expr.ceIdx).splitLinear(functions)
-            return result
-        else:
-            diff = Diff(arg, label=expr.label, ceIdx=expr.ceIdx)
-            if diff == 0:
-                return 0
-            else:
-                return diff.splitLinear(functions)
-    else:
-        newArgs = [expandUsingLinearity(e, functions) for e in expr.args]
-        result = sp.expand(expr.func(*newArgs) if newArgs else expr)
-        return result
-
-
-def fullDiffExpand(expr, functions=None, constants=None):
-    if functions is None:
-        functions = expr.atoms(sp.Symbol)
-        if constants is not None:
-            functions.difference_update(constants)
-
-    def visit(e):
-        e = e.expand()
-
-        if e.func == Diff:
-            result = 0
-            diffArgs = {'label': e.label, 'ceIdx': e.ceIdx}
-            diffInner = e.args[0]
-            diffInner = visit(diffInner)
-            for term in diffInner.args if diffInner.func == sp.Add else [diffInner]:
-                independentTerms = 1
-                dependentTerms = []
-                for factor in normalizeProduct(term):
-                    if factor in functions or isinstance(factor, Diff):
-                        dependentTerms.append(factor)
-                    else:
-                        independentTerms *= factor
-                for i in range(len(dependentTerms)):
-                    dependentTerm = dependentTerms[i]
-                    otherDependentTerms = dependentTerms[:i] + dependentTerms[i+1:]
-                    processedDiff = normalizeDiffOrder(Diff(dependentTerm, **diffArgs))
-                    result += independentTerms * prod(otherDependentTerms) * processedDiff
-            return result
-        else:
-            newArgs = [visit(arg) for arg in e.args]
-            return e.func(*newArgs) if newArgs else e
-
-    if isinstance(expr, sp.Matrix):
-        return expr.applyfunc(visit)
-    else:
-        return visit(expr)
-
-
-def normalizeDiffOrder(expression, functions=None, constants=None, sortKey=defaultDiffSortKey):
-    """Assumes order of differentiation can be exchanged. Changes the order of nested Diffs to a standard order defined
-    by the sorting key 'sortKey' such that the derivative terms can be further simplified """
-    def visit(expr):
-        if isinstance(expr, Diff):
-            nodes = [expr]
-            while isinstance(nodes[-1].arg, Diff):
-                nodes.append(nodes[-1].arg)
-
-            processedArg = visit(nodes[-1].arg)
-            nodes.sort(key=sortKey)
-
-            result = processedArg
-            for d in reversed(nodes):
-                result = Diff(result, label=d.label, ceIdx=d.ceIdx)
-            return result
-        else:
-            newArgs = [visit(e) for e in expr.args]
-            return expr.func(*newArgs) if newArgs else expr
-
-    expression = expandUsingLinearity(expression.expand(), functions, constants).expand()
-    return visit(expression)
+from pystencils.derivative import *
 
 
 def chapmanEnskogDerivativeExpansion(expr, label, eps=sp.Symbol("epsilon"), startOrder=1, stopOrder=4):
@@ -316,148 +18,4 @@ def chapmanEnskogDerivativeRecombination(expr, label, eps=sp.Symbol("epsilon"),
         substitution = Diff(d.arg, label)
         substitution -= sum([eps ** i * Diff(d.arg, label, i) for i in range(startOrder, stopOrder - 1)])
         expr = expr.subs(d, substitution / eps**(stopOrder-1))
-    return expr
-
-
-def expandUsingProductRule(expr):
-    """Fully expands all derivatives by applying product rule"""
-    if isinstance(expr, Diff):
-        arg = expandUsingProductRule(expr.args[0])
-        if arg.func == sp.Add:
-            newArgs = [Diff(e, label=expr.label, ceIdx=expr.ceIdx)
-                       for e in arg.args]
-            return sp.Add(*newArgs)
-        if arg.func not in (sp.Mul, sp.Pow):
-            return Diff(arg, label=expr.label, ceIdx=expr.ceIdx)
-        else:
-            prodList = normalizeProduct(arg)
-            result = 0
-            for i in range(len(prodList)):
-                preFactor = prod(prodList[j] for j in range(len(prodList)) if i != j)
-                result += preFactor * Diff(prodList[i], label=expr.label, ceIdx=expr.ceIdx)
-            return result
-    else:
-        newArgs = [expandUsingProductRule(e) for e in expr.args]
-        return expr.func(*newArgs) if newArgs else expr
-
-
-def combineUsingProductRule(expr):
-    """Inverse product rule"""
-
-    def exprToDiffDecomposition(expr):
-        """Decomposes a sp.Add node containing CeDiffs into:
-        diffDict: maps (label, ceIdx) -> [ (preFactor, argument), ... ]
-        i.e.  a partial(b) ( a is prefactor, b is argument)
-            in case of partial(a) partial(b) two entries are created  (0.5 partial(a), b), (0.5 partial(b), a) 
-        """
-        DiffInfo = namedtuple("DiffInfo", ["label", "ceIdx"])
-
-        class DiffSplit:
-            def __init__(self, preFactor, argument):
-                self.preFactor = preFactor
-                self.argument = argument
-
-            def __repr__(self):
-                return str((self.preFactor, self.argument))
-
-        assert isinstance(expr, sp.Add)
-        diffDict = defaultdict(list)
-        rest = 0
-        for term in expr.args:
-            if isinstance(term, Diff):
-                diffDict[DiffInfo(term.label, term.ceIdx)].append(DiffSplit(1, term.arg))
-            else:
-                mulArgs = normalizeProduct(term)
-                diffs = [d for d in mulArgs if isinstance(d, Diff)]
-                factor = prod(d for d in mulArgs if not isinstance(d, Diff))
-                if len(diffs) == 0:
-                    rest += factor
-                else:
-                    for i, diff in enumerate(diffs):
-                        allButCurrent = [d for j, d in enumerate(diffs) if i != j]
-                        preFactor = factor * prod(allButCurrent) * sp.Rational(1, len(diffs))
-                        diffDict[DiffInfo(diff.label, diff.ceIdx)].append(DiffSplit(preFactor, diff.arg))
-
-        return diffDict, rest
-
-    def matchDiffSplits(own, other):
-        ownFac = own.preFactor / other.argument
-        otherFac = other.preFactor / own.argument
-
-        if sp.count_ops(ownFac) > sp.count_ops(own.preFactor) or sp.count_ops(otherFac) > sp.count_ops(other.preFactor):
-            return None
-
-        newOtherFactor = ownFac - otherFac
-        return newOtherFactor
-
-    def processDiffList(diffList, label, ceIdx):
-        if len(diffList) == 0:
-            return 0
-        elif len(diffList) == 1:
-            return diffList[0].preFactor * Diff(diffList[0].argument, label, ceIdx)
-
-        result = 0
-        matches = []
-        for i in range(1, len(diffList)):
-            matchResult = matchDiffSplits(diffList[i], diffList[0])
-            if matchResult is not None:
-                matches.append((i, matchResult))
-
-        if len(matches) == 0:
-            result += diffList[0].preFactor * Diff(diffList[0].argument, label, ceIdx)
-        else:
-            otherIdx, matchResult = sorted(matches, key=lambda e: sp.count_ops(e[1]))[0]
-            newArgument = diffList[0].argument * diffList[otherIdx].argument
-            result += (diffList[0].preFactor / diffList[otherIdx].argument) * Diff(newArgument, label, ceIdx)
-            if matchResult == 0:
-                del diffList[otherIdx]
-            else:
-                diffList[otherIdx].preFactor = matchResult * diffList[0].argument
-        result += processDiffList(diffList[1:], label, ceIdx)
-        return result
-
-    expr = expr.expand()
-    if isinstance(expr, sp.Add):
-        diffDict, rest = exprToDiffDecomposition(expr)
-        for (label, ceIdx), diffList in diffDict.items():
-            rest += processDiffList(diffList, label, ceIdx)
-        return rest
-    else:
-        newArgs = [combineUsingProductRule(e) for e in expr.args]
-        return expr.func(*newArgs) if newArgs else expr
-
-
-def replaceDiff(expr, replacementDict):
-    """replacementDict: maps variable (label) to a new Differential operator"""
-
-    def visit(e):
-        if isinstance(e, Diff):
-            if e.label in replacementDict:
-                return DiffOperator.apply(replacementDict[e.label], visit(e.arg))
-        newArgs = [visit(arg) for arg in e.args]
-        return e.func(*newArgs) if newArgs else e
-
-    return visit(expr)
-
-
-def zeroDiffs(expr, label):
-    """Replaces all differentials with the given label by 0"""
-    def visit(e):
-        if isinstance(e, Diff):
-            if e.label == label:
-                return 0
-        newArgs = [visit(arg) for arg in e.args]
-        return e.func(*newArgs) if newArgs else e
-    return visit(expr)
-
-
-def evaluateDiffs(expr, var=None):
-    """Replaces Diff nodes by sp.diff , the free variable is either the label (if var=None) otherwise
-    the specified var"""
-    if isinstance(expr, Diff):
-        if var is None:
-            var = expr.label
-        return sp.diff(evaluateDiffs(expr.arg, var), var)
-    else:
-        newArgs = [evaluateDiffs(arg, var) for arg in expr.args]
-        return expr.func(*newArgs) if newArgs else expr
+    return expr
\ No newline at end of file
diff --git a/phasefield/analytical.py b/phasefield/analytical.py
index cb981e83de4ea66fdff53d86b2e05f82855eb6bc..977096920677bb822022540d73e963c8ceac70d8 100644
--- a/phasefield/analytical.py
+++ b/phasefield/analytical.py
@@ -1,9 +1,8 @@
 import sympy as sp
 from collections import defaultdict
 
-from lbmpy.chapman_enskog.derivative import expandUsingLinearity, Diff, fullDiffExpand
-from pystencils.equationcollection.simplifications import sympyCseOnEquationList
 from pystencils.sympyextensions import multidimensionalSummation as multiSum, normalizeProduct, prod
+from pystencils.derivative import functionalDerivative, expandUsingLinearity, Diff, fullDiffExpand
 
 orderParameterSymbolName = "phi"
 surfaceTensionSymbolName = "tau"
@@ -219,36 +218,6 @@ def substituteLaplacianBySum(eq, dim):
     return fullDiffExpand(eq.subs(substitutions))
 
 
-def functionalDerivative(functional, v, constants=None):
-    """
-    Computes functional derivative of functional with respect to v using Euler-Lagrange equation
-
-    .. math ::
-
-        \frac{\delta F}{\delta v} =
-                \frac{\partial F}{\partial v} - \nabla \cdot \frac{\partial F}{\partial \nabla v}
-
-    - assumes that gradients are represented by Diff() node (from Chapman Enskog module)
-    - Diff(Diff(r)) represents the divergence of r
-    - the constants parameter is a list with symbols not affected by the derivative. This is used for simplification
-      of the derivative terms.
-    """
-    functional = expandUsingLinearity(functional, constants=constants)
-    diffs = functional.atoms(Diff)
-
-    diffV = Diff(v)
-
-    nonDiffPart = functional.subs({d: 0 for d in diffs})
-
-    partialF_partialV = sp.diff(nonDiffPart, v)
-
-    dummy = sp.Dummy()
-    partialF_partialGradV = functional.subs(diffV, dummy).diff(dummy).subs(dummy, diffV)
-
-    result = partialF_partialV - Diff(partialF_partialGradV)
-    return expandUsingLinearity(result, constants=constants)
-
-
 def coshIntegral(f, var):
     """Integrates a function f that has exactly one cosh term, from -oo to oo, by
     substituting a new helper variable for the cosh argument"""
@@ -259,39 +228,6 @@ def coshIntegral(f, var):
     return sp.integrate(transformedInt.args[0], (transformedInt.args[1][0], -sp.oo, sp.oo))
 
 
-def finiteDifferences2ndOrder(term, dx=1):
-    """Substitutes symbolic integral of field access by second order accurate finite differences.
-    The only valid argument of Diff objects are field accesses (usually center field accesses)"""
-    def diffOrder(e):
-        if not isinstance(e, Diff):
-            return 0
-        else:
-            return 1 + diffOrder(e.args[0])
-
-    def visit(e):
-        order = diffOrder(e)
-        if order == 0:
-            paramList = [visit(a) for a in e.args]
-            return e if not paramList else e.func(*paramList)
-        elif order == 1:
-            fa = e.args[0]
-            index = e.label
-            return (fa.neighbor(index, 1) - fa.neighbor(index, -1)) / (2 * dx)
-        elif order == 2:
-            indices = sorted([e.label, e.args[0].label])
-            fa = e.args[0].args[0]
-            if indices[0] == indices[1]:
-                result = (-2 * fa + fa.neighbor(indices[0], -1) + fa.neighbor(indices[0], +1))
-            else:
-                offsets = [(1,1), [-1, 1], [1, -1], [-1, -1]]
-                result = sum(o1*o2 * fa.neighbor(indices[0], o1).neighbor(indices[1], o2) for o1, o2 in offsets) / 4
-            return result / (dx**2)
-        else:
-            raise NotImplementedError("Term contains derivatives of order > 2")
-
-    return visit(term)
-
-
 def symmetricTensorLinearization(dim):
     nextIdx = 0
     resultMap = {}
diff --git a/phasefield/kerneleqs.py b/phasefield/kerneleqs.py
index 56f84746ba0d5a9441d808b309f79fad2e34bb57..008dda4d3b677c91fd653c802f956f2936503be1 100644
--- a/phasefield/kerneleqs.py
+++ b/phasefield/kerneleqs.py
@@ -1,7 +1,7 @@
 import sympy as sp
+from pystencils.finitedifferences import Discretization2ndOrder
 from lbmpy.phasefield.analytical import chemicalPotentialsFromFreeEnergy, substituteLaplacianBySum, \
-    finiteDifferences2ndOrder, forceFromPhiAndMu, symmetricTensorLinearization, pressureTensorFromFreeEnergy, \
-    forceFromPressureTensor
+    forceFromPhiAndMu, symmetricTensorLinearization, pressureTensorFromFreeEnergy, forceFromPressureTensor
 
 
 # ---------------------------------- Kernels to compute force ----------------------------------------------------------
@@ -14,15 +14,17 @@ def muKernel(freeEnergy, orderParameters, phiField, muField, dx=1):
     chemicalPotential = chemicalPotentialsFromFreeEnergy(freeEnergy, orderParameters)
     chemicalPotential = substituteLaplacianBySum(chemicalPotential, dim)
     chemicalPotential = chemicalPotential.subs({op: phiField(i) for i, op in enumerate(orderParameters)})
-    return [sp.Eq(muField(i), finiteDifferences2ndOrder(mu_i, dx)) for i, mu_i in enumerate(chemicalPotential)]
+    discretize = Discretization2ndOrder(dx=dx)
+    return [sp.Eq(muField(i), discretize(mu_i)) for i, mu_i in enumerate(chemicalPotential)]
 
 
 def forceKernelUsingMu(forceField, phiField, muField, dx=1):
     """Computes forces using precomputed chemical potential - needs muKernel first"""
     assert muField.indexDimensions == 1
     force = forceFromPhiAndMu(phiField.vecCenter, mu=muField.vecCenter, dim=muField.spatialDimensions)
+    discretize = Discretization2ndOrder(dx=dx)
     return [sp.Eq(forceField(i),
-                  finiteDifferences2ndOrder(f_i, dx)).expand() for i, f_i in enumerate(force)]
+                  discretize(f_i)).expand() for i, f_i in enumerate(force)]
 
 
 def pressureTensorKernel(freeEnergy, orderParameters, phiField, pressureTensorField, dx=1):
@@ -30,11 +32,11 @@ def pressureTensorKernel(freeEnergy, orderParameters, phiField, pressureTensorFi
     p = pressureTensorFromFreeEnergy(freeEnergy, orderParameters, dim)
     p = p.subs({op: phiField(i) for i, op in enumerate(orderParameters)})
     indexMap = symmetricTensorLinearization(dim)
-
+    discretize = Discretization2ndOrder(dx=dx)
     eqs = []
     for index, linIndex in indexMap.items():
         eq = sp.Eq(pressureTensorField(linIndex),
-                   finiteDifferences2ndOrder(p[index], dx).expand())
+                   discretize(p[index]).expand())
         eqs.append(eq)
     return eqs
 
@@ -43,11 +45,12 @@ def forceKernelUsingPressureTensor(forceField, pressureTensorField, extraForce=N
     dim = forceField.spatialDimensions
     indexMap = symmetricTensorLinearization(dim)
 
-    p = sp.Matrix(dim, dim, lambda i, j: pressureTensorField(indexMap[i,j] if i < j else indexMap[j, i]))
+    p = sp.Matrix(dim, dim, lambda i, j: pressureTensorField(indexMap[i, j] if i < j else indexMap[j, i]))
     f = forceFromPressureTensor(p)
     if extraForce:
         f += extraForce
-    return [sp.Eq(forceField(i), finiteDifferences2ndOrder(f_i, dx).expand())
+    discretize = Discretization2ndOrder(dx=dx)
+    return [sp.Eq(forceField(i), discretize(f_i).expand())
             for i, f_i in enumerate(f)]
 
 
@@ -55,7 +58,7 @@ def forceKernelUsingPressureTensor(forceField, pressureTensorField, extraForce=N
 
 
 def cahnHilliardFdEq(phaseIdx, phi, mu, velocity, mobility, dx, dt):
-    from pystencils.finitedifferences import transient, advection, diffusion, Discretization2ndOrder
+    from pystencils.finitedifferences import transient, advection, diffusion
     cahnHilliard = transient(phi, phaseIdx) + advection(phi, velocity, phaseIdx) - diffusion(mu, mobility, phaseIdx)
     return Discretization2ndOrder(dx, dt)(cahnHilliard)
 
@@ -97,4 +100,4 @@ class CahnHilliardFDStep:
         pass
 
     def postRun(self):
-        pass
\ No newline at end of file
+        pass