diff --git a/examples/dem.py b/examples/dem.py
index 1bb5ea9e1ccf528bfea267f99995c871c7dcf483..9aec3a121ac26a2fe5c441c3fe812b59cc37db9d 100644
--- a/examples/dem.py
+++ b/examples/dem.py
@@ -137,6 +137,11 @@ psim = pairs.simulation(
     double_prec=True,
     use_contact_history=True)
 
+if target == 'gpu':
+    psim.target(pairs.target_gpu())
+else:
+    psim.target(pairs.target_cpu())
+
 psim.add_position('position')
 psim.add_property('mass', pairs.real(), 1.0)
 psim.add_property('linear_velocity', pairs.vector())
@@ -196,10 +201,4 @@ psim.compute(linear_spring_dashpot,
                       'collisionTime_SI': collisionTime_SI})
 
 psim.compute(euler, symbols={'dt': dt_SI})
-
-if target == 'gpu':
-    psim.target(pairs.target_gpu())
-else:
-    psim.target(pairs.target_cpu())
-
 psim.generate()
diff --git a/src/pairs/sim/contact_history.py b/src/pairs/sim/contact_history.py
index 8eccab15a81b953eed663a65738adce2e6f79dd1..9a1a94d76e41ce8bd045ace9df45fcb45e0e8838 100644
--- a/src/pairs/sim/contact_history.py
+++ b/src/pairs/sim/contact_history.py
@@ -1,7 +1,7 @@
 from pairs.ir.assign import Assign
 from pairs.ir.block import pairs_device_block
 from pairs.ir.branches import Branch, Filter
-from pairs.ir.loops import ParticleFor, For
+from pairs.ir.loops import ParticleFor, For, While
 from pairs.ir.scalars import ScalarOp
 from pairs.ir.types import Types
 from pairs.ir.utils import Print
@@ -95,12 +95,33 @@ class ClearUnusedContactHistory(Lowerable):
 
     @pairs_device_block
     def lower(self):
+        contact_lists = self.contact_history.contact_lists
         contact_used = self.contact_history.contact_used
         num_contacts = self.contact_history.num_contacts
         self.sim.module_name("clear_unused_contact_history")
 
-        for i in ParticleFor(self.sim):
-            for c in For(self.sim, 0, num_contacts[i]):
-                for _ in Filter(self.sim, ScalarOp.cmp(contact_used[i][c], 0)):
-                    for contact_prop in self.sim.contact_properties:
-                        Assign(self.sim, contact_prop[i, c], contact_prop.default())
+        if self.sim.neighbor_lists is None:
+            for i in ParticleFor(self.sim):
+                c = self.sim.add_temp_var(0)
+                for _ in While(self.sim, c < num_contacts[i]):
+                    for unused in Branch(self.sim, ScalarOp.cmp(contact_used[i][c], 0)):
+                        if unused:
+                            last_contact = num_contacts[i] - 1
+                            for _ in Filter(self.sim, last_contact > 0):
+                                for contact_prop in self.sim.contact_properties:
+                                    Assign(self.sim, contact_prop[i, c], contact_prop[i, last_contact])
+
+                                Assign(self.sim, contact_lists[i][c], contact_lists[i][last_contact])
+                                Assign(self.sim, contact_used[i][c], contact_used[i][last_contact])
+
+                            Assign(self.sim, num_contacts[i], num_contacts[i] - 1)
+
+                        else:
+                            Assign(self.sim, c, c + 1)
+
+        else:
+            for i in ParticleFor(self.sim):
+                for c in For(self.sim, 0, num_contacts[i]):
+                    for _ in Filter(self.sim, ScalarOp.cmp(contact_used[i][c], 0)):
+                        for contact_prop in self.sim.contact_properties:
+                            Assign(self.sim, contact_prop[i, c], contact_prop.default())