diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 10e855df2fd9a661f053adbdb43dbfb8d969fdf4..b081f752945133d56a4d32407e335b26d13614ec 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -310,6 +310,7 @@ class Block(Node):
 
     def insert_before(self, new_node, insert_before):
         new_node.parent = self
+        assert self._nodes.count(insert_before) == 1
         idx = self._nodes.index(insert_before)
 
         # move all assignment (definitions to the top)
@@ -337,6 +338,7 @@ class Block(Node):
         return tmp
 
     def replace(self, child, replacements):
+        assert self._nodes.count(child) == 1
         idx = self._nodes.index(child)
         del self._nodes[idx]
         if type(replacements) is list:
diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py
index f351ce5a2bb03d723d22a8e1f772b25a934f7994..9b119ea9a308726b8225ece66cddcb80ee3a4ef3 100644
--- a/pystencils/cpu/kernelcreation.py
+++ b/pystencils/cpu/kernelcreation.py
@@ -34,6 +34,7 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
                       transformation :func:`pystencils.transformation.split_inner_loop`
         iteration_slice: if not None, iteration is done only over this slice of the field
         ghost_layers: a sequence of pairs for each coordinate with lower and upper nr of ghost layers
+                      that should be excluded from the iteration.
                      if None, the number of ghost layers is determined automatically and assumed to be equal for a
                      all dimensions
         skip_independence_check: don't check that loop iterations are independent. This is needed e.g. for
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index 866f1a4ef5357e6a1a74d21b174d24785c398677..e0c635e06220f10198383e3c1fbbe0de7561280d 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -1,13 +1,15 @@
-import itertools
 from types import MappingProxyType
+from itertools import combinations
 
 import sympy as sp
 
 from pystencils.assignment import Assignment
 from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment
 from pystencils.cpu.vectorization import vectorize
+from pystencils.field import Field, FieldType
 from pystencils.gpucuda.indexing import indexing_creator_from_params
 from pystencils.simp.assignment_collection import AssignmentCollection
+from pystencils.stencil import direction_string_to_offset, inverse_direction_string
 from pystencils.transformations import (
     loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel)
 
@@ -23,7 +25,8 @@ def create_kernel(assignments,
                   cpu_blocking=None,
                   gpu_indexing='block',
                   gpu_indexing_params=MappingProxyType({}),
-                  use_textures_for_interpolation=True):
+                  use_textures_for_interpolation=True,
+                  cpu_prepend_optimizations=[]):
     """
     Creates abstract syntax tree (AST) of kernel, using a list of update equations.
 
@@ -34,9 +37,9 @@ def create_kernel(assignments,
                   to type
         iteration_slice: rectangular subset to iterate over, if not specified the complete non-ghost layer \
                          part of the field is iterated over
-        ghost_layers: if left to default, the number of necessary ghost layers is determined automatically
-                     a single integer specifies the ghost layer count at all borders, can also be a sequence of
-                     pairs ``[(x_lower_gl, x_upper_gl), .... ]``
+        ghost_layers: a single integer specifies the ghost layer count at all borders, can also be a sequence of
+                      pairs ``[(x_lower_gl, x_upper_gl), .... ]``. These layers are excluded from the iteration.
+                      If left to default, the number of ghost layers is determined automatically.
         skip_independence_check: don't check that loop iterations are independent. This is needed e.g. for
                                  periodicity kernel, that access the field outside the iteration bounds. Use with care!
         cpu_openmp: True or number of threads for OpenMP parallelization, False for no OpenMP
@@ -47,6 +50,7 @@ def create_kernel(assignments,
         gpu_indexing: either 'block' or 'line' , or custom indexing class, see `AbstractIndexing`
         gpu_indexing_params: dict with indexing parameters (constructor parameters of indexing class)
                              e.g. for 'block' one can specify '{'block_size': (20, 20, 10) }'
+        cpu_prepend_optimizations: list of extra optimizations to perform first on the AST
 
     Returns:
         abstract syntax tree (AST) object, that can either be printed as source code with `show_code` or
@@ -84,6 +88,8 @@ def create_kernel(assignments,
         ast = create_kernel(assignments, type_info=data_type, split_groups=split_groups,
                             iteration_slice=iteration_slice, ghost_layers=ghost_layers,
                             skip_independence_check=skip_independence_check)
+        for optimization in cpu_prepend_optimizations:
+            optimization(ast)
         omp_collapse = None
         if cpu_blocking:
             omp_collapse = loop_blocking(ast, cpu_blocking)
@@ -186,104 +192,132 @@ def create_indexed_kernel(assignments,
         raise ValueError("Unknown target %s. Has to be either 'cpu' or 'gpu'" % (target,))
 
 
-def create_staggered_kernel(staggered_field, expressions, subexpressions=(), target='cpu',
-                            gpu_exclusive_conditions=False, **kwargs):
+def create_staggered_kernel(assignments, target='cpu', gpu_exclusive_conditions=False, **kwargs):
     """Kernel that updates a staggered field.
 
     .. image:: /img/staggered_grid.svg
 
+    For a staggered field, the first index coordinate defines the location of the staggered value.
+    Further index coordinates can be used to store vectors/tensors at each point.
+
     Args:
-        staggered_field: field where the first index coordinate defines the location of the staggered value
-                can have 1 or 2 index coordinates, in case of two index coordinates at every staggered location
-                a vector is stored, expressions parameter has to be a sequence of sequences then
-                where e.g. ``f[0,0](0)`` is interpreted as value at the left cell boundary, ``f[1,0](0)`` the right cell
-                boundary and ``f[0,0](1)`` the southern cell boundary etc.
-        expressions: sequence of expressions of length dim, defining how the west, southern, (bottom) cell boundary
-                     should be updated.
-        subexpressions: optional sequence of Assignments, that define subexpressions used in the main expressions
-        target: 'cpu' or 'gpu'
-        gpu_exclusive_conditions: if/else construct to have only one code block for each of 2**dim code paths
-        kwargs: passed directly to create_kernel, iteration slice and ghost_layers parameters are not allowed
+        assignments: a sequence of assignments or an AssignmentCollection.
+                     Assignments to staggered field are processed specially, while subexpressions and assignments to
+                     regular fields are passed through to `create_kernel`. Multiple different staggered fields can be
+                     used, but they all need to use the same stencil (i.e. the same number of staggered points) and
+                     shape.
+        target: 'cpu', 'llvm' or 'gpu'
+        gpu_exclusive_conditions: disable the use of multiple conditionals inside the loop. The outer layers are then
+                                  handled in an else branch.
+        kwargs: passed directly to create_kernel, iteration_slice and ghost_layers parameters are not allowed
 
     Returns:
         AST, see `create_kernel`
     """
     assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs
-    assert staggered_field.index_dimensions in (1, 2), 'Staggered field must have one or two index dimensions'
+
+    if isinstance(assignments, AssignmentCollection):
+        subexpressions = assignments.subexpressions + [a for a in assignments.main_assignments
+                                                       if not hasattr(a, 'lhs')
+                                                       or type(a.lhs) is not Field.Access
+                                                       or not FieldType.is_staggered(a.lhs.field)]
+        assignments = [a for a in assignments.main_assignments if hasattr(a, 'lhs')
+                       and type(a.lhs) is Field.Access
+                       and FieldType.is_staggered(a.lhs.field)]
+    else:
+        subexpressions = [a for a in assignments if not hasattr(a, 'lhs')
+                          or type(a.lhs) is not Field.Access
+                          or not FieldType.is_staggered(a.lhs.field)]
+        assignments = [a for a in assignments if hasattr(a, 'lhs')
+                       and type(a.lhs) is Field.Access
+                       and FieldType.is_staggered(a.lhs.field)]
+    if len(set([tuple(a.lhs.field.staggered_stencil) for a in assignments])) != 1:
+        raise ValueError("All assignments need to be made to staggered fields with the same stencil")
+    if len(set([a.lhs.field.shape for a in assignments])) != 1:
+        raise ValueError("All assignments need to be made to staggered fields with the same shape")
+
+    staggered_field = assignments[0].lhs.field
+    stencil = staggered_field.staggered_stencil
     dim = staggered_field.spatial_dimensions
+    shape = staggered_field.shape
 
     counters = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dim)]
-    conditions = [counters[i] < staggered_field.shape[i] - 1 for i in range(dim)]
-    assert len(expressions) == dim
-    if staggered_field.index_dimensions == 2:
-        assert all(len(sublist) == len(expressions[0]) for sublist in expressions), \
-            "If staggered field has two index dimensions expressions has to be a sequence of sequences of all the " \
-            "same length."
 
     final_assignments = []
-    last_conditional = None
 
-    def add(condition, dimensions, as_else_block=False):
-        nonlocal last_conditional
-        if staggered_field.index_dimensions == 1:
-            assignments = [Assignment(staggered_field(d), expressions[d]) for d in dimensions]
-            a_coll = AssignmentCollection(assignments, list(subexpressions))
-            a_coll = a_coll.new_filtered([staggered_field(d) for d in dimensions])
-        elif staggered_field.index_dimensions == 2:
-            assert staggered_field.has_fixed_index_shape
-            assignments = [Assignment(staggered_field(d, i), expr)
-                           for d in dimensions
-                           for i, expr in enumerate(expressions[d])]
-            a_coll = AssignmentCollection(assignments, list(subexpressions))
-            a_coll = a_coll.new_filtered([staggered_field(d, i) for i in range(staggered_field.index_shape[1])
-                                          for d in dimensions])
-        sp_assignments = [SympyAssignment(a.lhs, a.rhs) for a in a_coll.all_assignments]
-        if as_else_block and last_conditional:
-            new_cond = Conditional(condition, Block(sp_assignments))
-            last_conditional.false_block = Block([new_cond])
-            last_conditional = new_cond
-        else:
-            last_conditional = Conditional(condition, Block(sp_assignments))
-            final_assignments.append(last_conditional)
+    # find out whether any of the ghost layers is not needed
+    common_exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
+    for direction in stencil:
+        exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
+        for elementary_direction in direction:
+            exclusions.remove(inverse_direction_string(elementary_direction))
+        common_exclusions.intersection_update(exclusions)
+    ghost_layers = [[0, 0] for d in range(dim)]
+    for direction in common_exclusions:
+        direction = direction_string_to_offset(direction)
+        for d, s in enumerate(direction):
+            if s == 1:
+                ghost_layers[d][1] = 1
+            elif s == -1:
+                ghost_layers[d][0] = 1
 
-    if target == 'cpu' or not gpu_exclusive_conditions:
-        for d in range(dim):
-            cond = sp.And(*[conditions[i] for i in range(dim) if d != i])
-            add(cond, [d])
-    elif target == 'gpu':
-        full_conditions = [sp.And(*[conditions[i] for i in range(dim) if d != i]) for d in range(dim)]
-        for include in itertools.product(*[[1, 0]] * dim):
-            case_conditions = sp.And(*[c if value else sp.Not(c) for c, value in zip(full_conditions, include)])
-            dimensions_to_include = [i for i in range(dim) if include[i]]
-            if dimensions_to_include:
-                add(case_conditions, dimensions_to_include, True)
+    def condition(direction):
+        """exclude those staggered points that correspond to fluxes between ghost cells"""
+        exclusions = set(["E", "W", "N", "S", "T", "B"][:2 * dim])
+
+        for elementary_direction in direction:
+            exclusions.remove(inverse_direction_string(elementary_direction))
+        conditions = []
+        for e in exclusions:
+            if e in common_exclusions:
+                continue
+            offset = direction_string_to_offset(e)
+            for i, o in enumerate(offset):
+                if o == 1:
+                    conditions.append(counters[i] < shape[i] - 1)
+                elif o == -1:
+                    conditions.append(counters[i] > 0)
+        return sp.And(*conditions)
 
-    ghost_layers = [(1, 0)] * dim
+    if gpu_exclusive_conditions:
+        outer_assignment = None
+        conditions = {direction: condition(direction) for direction in stencil}
+        for num_conditions in range(len(stencil)):
+            for combination in combinations(conditions.values(), num_conditions):
+                for assignment in assignments:
+                    direction = stencil[assignment.lhs.index[0]]
+                    if conditions[direction] in combination:
+                        assignment = SympyAssignment(assignment.lhs, assignment.rhs)
+                        outer_assignment = Conditional(sp.And(*combination), Block([assignment]), outer_assignment)
 
-    blocking = kwargs.get('cpu_blocking', None)
-    if blocking:
-        del kwargs['cpu_blocking']
+        inner_assignment = []
+        for assignment in assignments:
+            direction = stencil[assignment.lhs.index[0]]
+            inner_assignment.append(SympyAssignment(assignment.lhs, assignment.rhs))
+        last_conditional = Conditional(sp.And(*[condition(d) for d in stencil]),
+                                       Block(inner_assignment), outer_assignment)
+        final_assignments = [s for s in subexpressions if not hasattr(s, 'lhs')] + \
+                            [SympyAssignment(s.lhs, s.rhs) for s in subexpressions if hasattr(s, 'lhs')] + \
+                            [last_conditional]
 
-    cpu_vectorize_info = kwargs.get('cpu_vectorize_info', None)
-    if cpu_vectorize_info:
-        del kwargs['cpu_vectorize_info']
-    openmp = kwargs.get('cpu_openmp', None)
-    if openmp:
-        del kwargs['cpu_openmp']
+        if target == 'cpu':
+            from pystencils.cpu import create_kernel as create_kernel_cpu
+            ast = create_kernel_cpu(final_assignments, ghost_layers=ghost_layers, **kwargs)
+        else:
+            ast = create_kernel(final_assignments, ghost_layers=ghost_layers, target=target, **kwargs)
+        return ast
 
-    ast = create_kernel(final_assignments, ghost_layers=ghost_layers, target=target, **kwargs)
+    for assignment in assignments:
+        direction = stencil[assignment.lhs.index[0]]
+        sp_assignments = [s for s in subexpressions if not hasattr(s, 'lhs')] + \
+                         [SympyAssignment(s.lhs, s.rhs) for s in subexpressions if hasattr(s, 'lhs')] + \
+                         [SympyAssignment(assignment.lhs, assignment.rhs)]
+        last_conditional = Conditional(condition(direction), Block(sp_assignments))
+        final_assignments.append(last_conditional)
 
-    if target == 'cpu':
-        remove_conditionals_in_staggered_kernel(ast)
-        move_constants_before_loop(ast)
-        omp_collapse = None
-        if blocking:
-            omp_collapse = loop_blocking(ast, blocking)
-        if openmp:
-            from pystencils.cpu import add_openmp
-            add_openmp(ast, num_threads=openmp, collapse=omp_collapse, assume_single_outer_loop=False)
-        if cpu_vectorize_info is True:
-            vectorize(ast)
-        elif isinstance(cpu_vectorize_info, dict):
-            vectorize(ast, **cpu_vectorize_info)
+    remove_start_conditional = any([gl[0] == 0 for gl in ghost_layers])
+    prepend_optimizations = [lambda ast: remove_conditionals_in_staggered_kernel(ast, remove_start_conditional),
+                             move_constants_before_loop]
+    ast = create_kernel(final_assignments, ghost_layers=ghost_layers, target=target,
+                        cpu_prepend_optimizations=prepend_optimizations, **kwargs)
     return ast
diff --git a/pystencils/stencil.py b/pystencils/stencil.py
index 9f70336f3a3b3043fad322bdc67e134d5e11906b..32b1283fd969deba2f45df1ed4526cb000817072 100644
--- a/pystencils/stencil.py
+++ b/pystencils/stencil.py
@@ -16,6 +16,11 @@ def inverse_direction(direction):
     return tuple([-i for i in direction])
 
 
+def inverse_direction_string(direction):
+    """Returns inverse of given direction string"""
+    return offset_to_direction_string(inverse_direction(direction_string_to_offset(direction)))
+
+
 def is_valid(stencil, max_neighborhood=None):
     """
     Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length.
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index ea340e930631a159cf81cf9d56337b315e27c88b..762c36136cd7f3eb541a6075f4b24021813c82ad 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -630,13 +630,18 @@ def move_constants_before_loop(ast_node):
                     else:
                         target.insert_before(child, child_to_insert_before)
                 elif exists_already and exists_already.rhs == child.rhs:
-                    pass
+                    if target.args.index(exists_already) > target.args.index(child_to_insert_before):
+                        assert target.args.count(exists_already) == 1
+                        assert target.args.count(child_to_insert_before) == 1
+                        target.args.remove(exists_already)
+                        target.insert_before(exists_already, child_to_insert_before)
                 else:
                     # this variable already exists in outer block, but with different rhs
                     # -> symbol has to be renamed
                     assert isinstance(child.lhs, TypedSymbol)
                     new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
-                    target.insert_before(ast.SympyAssignment(new_symbol, child.rhs), child_to_insert_before)
+                    target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
+                                         child_to_insert_before)
                     substitute_variables[child.lhs] = new_symbol
 
 
@@ -1064,15 +1069,19 @@ def insert_casts(node):
     return node.func(*args)
 
 
-def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
-    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element"""
+def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None:
+    """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or
+       first and last element"""
 
     all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop]
     assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
     inner_loop = all_inner_loops.pop()
 
     for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
-        cut_loop(loop, [loop.stop - 1])
+        if include_first:
+            cut_loop(loop, [loop.start + 1, loop.stop - 1])
+        else:
+            cut_loop(loop, [loop.stop - 1])
 
     simplify_conditionals(function_node.body, loop_counter_simplification=True)
     cleanup_blocks(function_node.body)
diff --git a/pystencils_tests/test_blocking_staggered.py b/pystencils_tests/test_blocking_staggered.py
index 76ec8abf0e3b76d30415571bb43e975f317e9a3f..a79efe7c4445faa9baeb8323383b382a42f2cf33 100644
--- a/pystencils_tests/test_blocking_staggered.py
+++ b/pystencils_tests/test_blocking_staggered.py
@@ -11,8 +11,9 @@ def test_blocking_staggered():
        f[0, 0, 0] - f[0, -1, 0],
        f[0, 0, 0] - f[0, 0, -1],
     ]
-    kernel = ps.create_staggered_kernel(stag, terms, cpu_blocking=(3, 16, 8)).compile()
-    reference_kernel = ps.create_staggered_kernel(stag, terms).compile()
+    assignments = [ps.Assignment(stag.staggered_access(d), terms[i]) for i, d in enumerate(stag.staggered_stencil)]
+    kernel = ps.create_staggered_kernel(assignments, cpu_blocking=(3, 16, 8)).compile()
+    reference_kernel = ps.create_staggered_kernel(assignments).compile()
     print(ps.show_code(kernel.ast))
 
     f_arr = np.random.rand(80, 33, 19)
diff --git a/pystencils_tests/test_loop_cutting.py b/pystencils_tests/test_loop_cutting.py
index 999e7b52a8b40111243c09aca1aa3fc1549a0cc2..cd89f37f6f365b4223e1463db68874f50e81c46d 100644
--- a/pystencils_tests/test_loop_cutting.py
+++ b/pystencils_tests/test_loop_cutting.py
@@ -55,7 +55,8 @@ def test_staggered_iteration():
         for d in range(dim):
             expressions.append(sum(f[o] for o in offsets_in_plane(d, 0, dim)) -
                                sum(f[o] for o in offsets_in_plane(d, -1, dim)))
-        func_optimized = create_staggered_kernel(s, expressions).compile()
+        assignments = [ps.Assignment(s.staggered_access(d), expressions[i]) for i, d in enumerate(s.staggered_stencil)]
+        func_optimized = create_staggered_kernel(assignments).compile()
         assert not func_optimized.ast.atoms(Conditional), "Loop cutting optimization did not work"
 
         func(f=f_arr, s=s_arr_ref)
@@ -111,8 +112,10 @@ def test_staggered_gpu():
     s = ps.fields("s({dim}): double[{dim}D]".format(dim=dim), field_type=FieldType.STAGGERED)
     expressions = [(f[0, 0] + f[-1, 0]) / 2,
                    (f[0, 0] + f[0, -1]) / 2]
-    kernel_ast = ps.create_staggered_kernel(s, expressions, target='gpu', gpu_exclusive_conditions=True)
+    assignments = [ps.Assignment(s.staggered_access(d), expressions[i]) for i, d in enumerate(s.staggered_stencil)]
+    kernel_ast = ps.create_staggered_kernel(assignments, target='gpu', gpu_exclusive_conditions=True)
     assert len(kernel_ast.atoms(Conditional)) == 4
 
-    kernel_ast = ps.create_staggered_kernel(s, expressions, target='gpu', gpu_exclusive_conditions=False)
+    assignments = [ps.Assignment(s.staggered_access(d), expressions[i]) for i, d in enumerate(s.staggered_stencil)]
+    kernel_ast = ps.create_staggered_kernel(assignments, target='gpu', gpu_exclusive_conditions=False)
     assert len(kernel_ast.atoms(Conditional)) == 3
diff --git a/pystencils_tests/test_staggered_kernel.py b/pystencils_tests/test_staggered_kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db538bf8f83d6779cb7dfd612fa9928acc23965
--- /dev/null
+++ b/pystencils_tests/test_staggered_kernel.py
@@ -0,0 +1,85 @@
+import pystencils as ps
+import numpy as np
+import sympy as sp
+
+
+class TestStaggeredDiffusion:
+    def _run(self, num_neighbors):
+        L = (40, 40)
+        D = 0.066
+        dt = 1
+        T = 100
+
+        dh = ps.create_data_handling(L, periodicity=True, default_target='cpu')
+
+        c = dh.add_array('c', values_per_cell=1)
+        j = dh.add_array('j', values_per_cell=num_neighbors, field_type=ps.FieldType.STAGGERED_FLUX)
+
+        x_staggered = - c[-1, 0] + c[0, 0]
+        y_staggered = - c[0, -1] + c[0, 0]
+        xy_staggered = - c[-1, -1] + c[0, 0]
+        xY_staggered = - c[-1, 1] + c[0, 0]
+
+        jj = j.staggered_access
+        divergence = -1 * D / (1 + sp.sqrt(2) if j.index_shape[0] == 4 else 1) * \
+            sum([jj(d) / sp.Matrix(ps.stencil.direction_string_to_offset(d)).norm() for d in j.staggered_stencil
+                + [ps.stencil.inverse_direction_string(d) for d in j.staggered_stencil]])
+
+        update = [ps.Assignment(c.center, c.center + dt * divergence)]
+        flux = [ps.Assignment(j.staggered_access("W"), x_staggered),
+                ps.Assignment(j.staggered_access("S"), y_staggered)]
+        if j.index_shape[0] == 4:
+            flux += [ps.Assignment(j.staggered_access("SW"), xy_staggered),
+                     ps.Assignment(j.staggered_access("NW"), xY_staggered)]
+
+        staggered_kernel = ps.create_staggered_kernel(flux, target=dh.default_target).compile()
+        div_kernel = ps.create_kernel(update, target=dh.default_target).compile()
+
+        def time_loop(steps):
+            sync = dh.synchronization_function([c.name])
+            dh.all_to_gpu()
+            for i in range(steps):
+                sync()
+                dh.run_kernel(staggered_kernel)
+                dh.run_kernel(div_kernel)
+            dh.all_to_cpu()
+
+        def init():
+            dh.fill(c.name, np.nan, ghost_layers=True, inner_ghost_layers=True)
+            dh.fill(c.name, 0)
+            dh.fill(j.name, np.nan, ghost_layers=True, inner_ghost_layers=True)
+            dh.cpu_arrays[c.name][L[0] // 2:L[0] // 2 + 2, L[1] // 2:L[1] // 2 + 2] = 1.0
+
+        init()
+        time_loop(T)
+
+        reference = np.empty(L)
+        for x in range(L[0]):
+            for y in range(L[1]):
+                r = np.array([x, y]) - L[0] / 2 + 0.5
+                reference[x, y] = (4 * np.pi * D * T)**(-dh.dim / 2) * np.exp(-np.dot(r, r) / (4 * D * T)) * (2**dh.dim)
+
+        assert np.abs(dh.gather_array(c.name) - reference).max() < 5e-4
+
+    def test_diffusion_2(self):
+        self._run(2)
+
+    def test_diffusion_4(self):
+        self._run(4)
+
+
+def test_staggered_subexpressions():
+    dh = ps.create_data_handling((10, 10), periodicity=True, default_target='cpu')
+    j = dh.add_array('j', values_per_cell=2, field_type=ps.FieldType.STAGGERED)
+    c = sp.symbols("c")
+    assignments = [ps.Assignment(j.staggered_access("W"), c),
+                   ps.Assignment(c, 1)]
+    ps.create_staggered_kernel(assignments, target=dh.default_target).compile()
+
+
+def test_staggered_loop_cutting():
+    dh = ps.create_data_handling((4, 4), periodicity=True, default_target='cpu')
+    j = dh.add_array('j', values_per_cell=4, field_type=ps.FieldType.STAGGERED)
+    assignments = [ps.Assignment(j.staggered_access("SW"), 1)]
+    ast = ps.create_staggered_kernel(assignments, target=dh.default_target)
+    assert not ast.atoms(ps.astnodes.Conditional)