diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index e40f54ffa17a8a1129d5bb1a5154747041ac0447..dae1bea3967079ec5e1b14d530a8979cad241218 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -19,6 +19,8 @@ class Comm:
         self.nsend         = sim.add_array('nsend', [max_neigh_ranks], Types.Int32)
         self.send_buffer   = sim.add_array('send_buffer', [self.send_capacity, max_buffer_elems], Types.Double)
         self.send_map      = sim.add_array('send_map', [self.send_capacity], Types.Int32)
+        self.exchg_flag    = sim.add_array('exchg_flag', [sim.particle_capacity], Types.Int32)
+        self.exchg_copy_to = sim.add_array('exchg_copy_to', [self.send_capacity], Types.Int32)
         self.send_mult     = sim.add_array('send_mult', [self.send_capacity, sim.ndims()], Types.Int32)
         self.nrecv         = sim.add_array('nrecv', [max_neigh_ranks], Types.Int32)
         self.recv_buffer   = sim.add_array('recv_buffer', [self.recv_capacity, max_buffer_elems], Types.Double)
@@ -27,10 +29,12 @@ class Comm:
 
 
 class DetermineGhostParticles(Lowerable):
-    def __init__(self, sim, comm, dom_part):
+    def __init__(self, sim, comm, dom_part, step, spacing):  # step and spacing are new required arguments
         super().__init__(sim)
         self.comm = comm
         self.dom_part = dom_part
+        self.step = step  # presumably the `step` argument expected by dom_part.ghost_particles() -- confirm
+        self.spacing = spacing  # passed as the `offset` argument of dom_part.ghost_particles() in lower()
 
     @pairs_device_block
     def lower(self):
@@ -42,7 +46,7 @@ class DetermineGhostParticles(Lowerable):
 
         nb_rank_id = 0
         nsend_all.set(0)
-        for i, _, pbc in self.dom_part.ghost_particles(self.sim.position(), self.sim.cell_spacing()):
+        for i, _, pbc in self.dom_part.ghost_particles(self.step, self.sim.position(), self.spacing):  # step is stored on self by __init__
             n = AtomicAdd(self.sim, nsend_all, 1)
             send_map[n].set(i)
             for d in self.sim.ndims():
@@ -118,7 +122,67 @@ class UnpackGhostParticles(Lowerable):
                     p_offset += 1
 
 
-class ExchangeParticles(Lowerable):
+class RemoveExchangedParticles_part1(Lowerable):  # host pass: records in exchg_copy_to, per send_map entry, the particle index to copy into that slot (or -1)
+    def __init__(self, sim, comm, prop_list):  # prop_list is accepted but not stored by this pass
+        super().__init__(sim)
+        self.comm = comm
+
+    @pairs_host_block
+    def lower(self):
+        # NOTE(review): the initial value was `grid.nparticles`, but no `grid` is bound in this method; the last local index is assumed -- confirm.
+        send_pos = self.sim.add_temp_var(self.sim.nlocal - 1)
+        self.sim.module_name("remove_exchanged_particles_pt1")
+        for i in For(self.sim, self.comm.nsend):
+            for is_local in Branch(self.sim, self.comm.send_map[i] < self.sim.nlocal - self.comm.nsend):
+                if is_local:  # slot index is below nlocal - nsend, so a replacement particle is chosen for it
+                    for _ in While(self.sim, BinOp.eq(self.comm.exchg_flag[send_pos], 1)):
+                        send_pos.assign(send_pos - 1)
+
+                    self.comm.exchg_copy_to[i].set(send_pos)
+                    send_pos.assign(send_pos - 1)
+
+                else:  # -1 marks a slot that receives no replacement
+                    self.comm.exchg_copy_to[i].set(-1)
+
+
+class RemoveExchangedParticles_part2(Lowerable):  # device pass: copies every property in prop_list from exchg_copy_to[i] into send_map[i], then shrinks nlocal
+    def __init__(self, sim, comm, prop_list):
+        super().__init__(sim)
+        self.comm = comm
+        self.prop_list = prop_list
+
+    @pairs_device_block
+    def lower(self):
+        self.sim.module_name("remove_exchanged_particles_pt2")
+        # NOTE(review): exchg_copy_to and send_map are filled per send slot in part 1, yet this loop is a ParticleFor -- confirm the intended iteration range.
+        for i in ParticleFor(self.sim):
+            src = self.comm.exchg_copy_to[i]  # index to copy from, or -1 as written by part 1
+            for _ in Filter(self.sim, src > 0):
+                dst = self.comm.send_map[i]  # slot of the particle that was sent away
+                for p in self.prop_list:
+                    if p.type() == Types.Vector:
+                        for d in range(self.sim.ndims()):  # ndims() is used as an integer count elsewhere in this patch
+                            p[dst][d].set(p[src][d])
+
+                    else:
+                        p[dst].set(p[src])
+
+        self.sim.nlocal.set(self.sim.nlocal - self.comm.nsend)
+
+
+class Synchronize(Lowerable):  # placeholder: only stores comm; no lower() is defined in this class yet
+    def __init__(self, sim, comm):
+        super().__init__(sim)
+        self.comm = comm
+
+
+class Borders(Lowerable):  # placeholder: only stores comm; no lower() is defined in this class yet
+    def __init__(self, sim, comm):
+        super().__init__(sim)
+        self.comm = comm
+
+
+class Exchange(Lowerable):
     def __init__(self, sim, comm):
         super().__init__(sim)
         self.comm = comm
diff --git a/src/pairs/sim/domain_partitioning.py b/src/pairs/sim/domain_partitioning.py
index 9a3076debd0b8436f67b9080c28c94cf181688c0..2606911863cc6509b0d57ee9d8a77f662dce3f7a 100644
--- a/src/pairs/sim/domain_partitioning.py
+++ b/src/pairs/sim/domain_partitioning.py
@@ -14,15 +14,16 @@ class DimensionRanges:
         self.pbc            = sim.add_static_array('pbc', [sim.ndims() * 2], Types.Int32)
         self.subdom         = sim.add_static_array('subdom', [sim.ndims() * 2], Types.Int32)
 
-    def ghost_particles(self, position, offset=0.0):
-        for dim in range(0, self.sim.ndims()):
-            nall = self.sim.nlocal + self.sim.comm.nghost
-            for i in For(sim, 0, nall):
-                j = dim * 2
-                for _ in Filter(self.sim, position < self.subdom[j] + offset):
-                    yield i, self.neighbor_ranks[j], [0 if d != dim else self.pbc[j] for d in self.sim.ndims()]
+    def number_of_steps(self):  # returns the dimension count; presumably the valid range of `step` for ghost_particles() -- confirm
+        return self.sim.ndims()
 
-            for i in For(sim, 0, nall):
-                j = dim * 2 + 1
-                for _ in Filter(self.sim, position > self.subdom[j] - offset):
-                    yield i, self.neighbor_ranks[j], [0 if d != dim else self.pbc[j] for d in self.sim.ndims()]
+    def ghost_particles(self, step, position, offset=0.0):  # yields (particle index, neighbor rank, per-dimension PBC list) for both bounds of dimension `step`
+        for i in For(self.sim, 0, self.sim.nlocal + self.sim.comm.nghost):
+            j = step * 2 + 0
+            for _ in Filter(self.sim, position[i][step] < self.subdom[j] + offset):  # test particle i's coordinate along `step`
+                yield i, self.neighbor_ranks[j], [0 if d != step else self.pbc[j] for d in range(self.sim.ndims())]
+
+        for i in For(self.sim, 0, self.sim.nlocal + self.sim.comm.nghost):
+            j = step * 2 + 1
+            for _ in Filter(self.sim, position[i][step] > self.subdom[j] - offset):  # same test against the opposite bound
+                yield i, self.neighbor_ranks[j], [0 if d != step else self.pbc[j] for d in range(self.sim.ndims())]