diff --git a/examples/dem.py b/examples/dem.py
index 0276de9b5bd338b6f05795026a65b749a345d308..2c9dcf3c5bac4294e832f5a85acaec2499ba8b01 100644
--- a/examples/dem.py
+++ b/examples/dem.py
@@ -97,8 +97,9 @@ if target != 'cpu' and target != 'gpu':
     print(f"Invalid target, use {cmd} <cpu/gpu>")
 
 # Config file parameters
-#domainSize_SI = [0.8, 0.015, 0.2]
-domainSize_SI = [0.6, 0.6, 0.2] # node base
+domainSize_SI = [0.8, 0.015, 0.2]
+#domainSize_SI = [0.4, 0.4, 0.2] # node base
+#domainSize_SI = [0.6, 0.6, 0.2] # node base
 #domainSize_SI = [0.8, 0.8, 0.2] # node base
 diameter_SI = 0.0029
 gravity_SI = 9.81
@@ -140,6 +141,7 @@ if target == 'gpu':
     psim.target(pairs.target_gpu())
 else:
     psim.target(pairs.target_cpu())
+    #psim.target(pairs.target_cpu(parallel=True))
 
 psim.add_position('position')
 psim.add_property('mass', pairs.real(), 1.0)
@@ -163,9 +165,9 @@ psim.set_domain([0.0, 0.0, 0.0, domainSize_SI[0], domainSize_SI[1], domainSize_S
 psim.set_domain_partitioner(pairs.regular_domain_partitioner_xy())
 psim.pbc([True, True, False])
 psim.read_particle_data(
-    #"data/spheres.input",
+    "data/spheres.input",
     #"data/spheres_4x4x2.input",
-    "data/spheres_6x6x2.input",
+    #"data/spheres_6x6x2.input",
     #"data/spheres_8x8x2.input",
     ['uid', 'type', 'mass', 'radius', 'position', 'linear_velocity', 'flags'],
     pairs.sphere())
@@ -185,7 +187,7 @@ psim.setup(update_mass_and_inertia, {'densityParticle_SI': densityParticle_SI,
                                      'infinity': math.inf })
 
 #psim.compute_half()
-psim.build_cell_lists(linkedCellWidth)
+psim.build_cell_lists(linkedCellWidth, store_neighbors_per_cell=True)
 #psim.vtk_output(f"output/dem_{target}", frequency=visSpacing)
 
 psim.compute(gravity,
diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py
index c64e81f455c26ada6b76a9c2a2191d651f702983..016ca1a96892e41502d97237a03edafec02d36be 100644
--- a/src/pairs/sim/cell_lists.py
+++ b/src/pairs/sim/cell_lists.py
@@ -37,6 +37,11 @@ class CellLists:
         self.stencil            =   self.sim.add_array('stencil', self.nstencil_max, Types.Int32)
         self.particle_cell      =   self.sim.add_array('particle_cell', self.sim.particle_capacity, Types.Int32)
 
+        if sim._store_neighbors_per_cell:
+            self.cell_neigh_capacity = self.sim.add_var('cell_neigh_capacity', Types.Int32, 80)
+            self.cell_nneighs = self.sim.add_array('cell_nneighs', [self.ncells_capacity, self.sim.max_shapes()], Types.Int32)
+            self.cell_neighbors = self.sim.add_array('cell_neighbors', [self.ncells_capacity, self.cell_neigh_capacity], Types.Int32)
+
 
 class BuildCellListsStencil(Lowerable):
     def __init__(self, sim, cell_lists):
@@ -164,3 +169,38 @@ class PartitionCellLists(Lowerable):
                         else:
                             Assign(self.sim, start, start + 1)
                             Assign(self.sim, self.cell_lists.nshapes[cell][shape], start - shape_start)
+
+
+class BuildCellNeighborLists(Lowerable):
+    def __init__(self, sim, cell_lists):
+        super().__init__(sim)
+        self.cell_lists = cell_lists
+
+    @pairs_device_block
+    def lower(self):
+        ncells = self.cell_lists.ncells
+        nshapes = self.cell_lists.nshapes
+        cell_particles = self.cell_lists.cell_particles
+        cell_nneighs = self.cell_lists.cell_nneighs
+        cell_neighbors = self.cell_lists.cell_neighbors
+        self.sim.module_name("build_cell_neighbor_lists")
+        self.sim.check_resize(self.cell_lists.cell_neigh_capacity, cell_nneighs)
+
+        for cell in For(self.sim, 0, ncells):
+            for shape in range(self.sim.max_shapes()):
+                Assign(self.sim, cell_nneighs[cell][shape], 0)
+
+                for disp in For(self.sim, -1, self.cell_lists.nstencil):
+                    neigh_cell = Select(self.sim, disp < 0, 0, cell + self.cell_lists.stencil[disp])
+
+                    for _ in Filter(self.sim, ScalarOp.or_op(disp < 0,
+                                                             ScalarOp.and_op(neigh_cell > 0,
+                                                                             neigh_cell < ncells))):
+
+                        start = sum([nshapes[neigh_cell][s] for s in range(shape)], 0)
+                        for cell_particle in For(self.sim, start, start + nshapes[neigh_cell][shape]):
+                            particle = cell_particles[neigh_cell][cell_particle]
+                            neighs_start = sum([cell_nneighs[cell][s] for s in range(shape)], 0)
+                            numneighs = cell_nneighs[cell][shape]
+                            Assign(self.sim, cell_neighbors[cell][neighs_start + numneighs], particle)
+                            Assign(self.sim, cell_nneighs[cell][shape], numneighs + 1)
diff --git a/src/pairs/sim/interaction.py b/src/pairs/sim/interaction.py
index edcb65b59da19e62f88fb70eecce4ea303ce5d38..4d3baa6d325faf614c781fd3e4f1d959a8f857ad 100644
--- a/src/pairs/sim/interaction.py
+++ b/src/pairs/sim/interaction.py
@@ -61,28 +61,57 @@ class NeighborFor:
             particle_shape = self.sim.particle_shape
             nshapes = self.cell_lists.nshapes
 
-            for shape in self.shapes:
-                for disp in For(self.sim, -1, self.cell_lists.nstencil):
-                    neigh_cell = \
-                        Select(self.sim, disp < 0, 0, particle_cell[self.particle] + stencil[disp])
+            if self.sim._store_neighbors_per_cell:
+                cell_nneighs = self.cell_lists.cell_nneighs
+                cell_neighbors = self.cell_lists.cell_neighbors
+
+                for shape in self.shapes:
+                    start = sum([cell_nneighs[cell][s] for s in range(shape)], 0)
+                    # FIXME: Without the inline, the 'cell' expression is being generated after
+                    # its usage in the loop upper limit
+                    cell = ScalarOp.inline(particle_cell[self.particle])
+                    for k in For(self.sim, start, start + cell_nneighs[cell][shape]):
+                        particle_id = cell_neighbors[cell][k]
+
+                        if self.sim._compute_half:
+                            shape_id = self.sim.get_shape_id(shape)
+                            shape_cond = particle_shape[particle_id] > shape_id
+                            condition = ScalarOp.or_op(shape_cond, self.particle < particle_id)
 
-                    for _ in Filter(self.sim, ScalarOp.or_op(disp < 0,
-                                                             ScalarOp.and_op(neigh_cell > 0,
-                                                                             neigh_cell < ncells))):
+                        else:
+                            condition = ScalarOp.neq(particle_id, self.particle)
 
-                        start = sum([nshapes[neigh_cell][s] for s in range(shape)], 0)
-                        for cell_particle in For(self.sim, start, start + nshapes[neigh_cell][shape]):
-                            particle_id = cell_particles[neigh_cell][cell_particle]
+                        for _ in Filter(self.sim, condition):
+                            yield Neighbor(self.sim, k, None, particle_id, shape)
 
-                            if self.sim._compute_half:
-                                shape_cond = particle_shape[particle_id] > self.sim.get_shape_id(shape)
-                                condition = ScalarOp.or_op(shape_cond, self.particle < particle_id)
+            else:
+                for shape in self.shapes:
+                    for disp in For(self.sim, -1, self.cell_lists.nstencil):
+                        neigh_cell = \
+                            Select(self.sim, disp < 0, 0, particle_cell[self.particle] + stencil[disp])
 
-                            else:
-                                condition = ScalarOp.neq(particle_id, self.particle)
+                        for _ in Filter(self.sim, ScalarOp.or_op(disp < 0,
+                                                                 ScalarOp.and_op(neigh_cell > 0,
+                                                                                 neigh_cell < ncells))):
+
+                            start = sum([nshapes[neigh_cell][s] for s in range(shape)], 0)
+                            for cell_particle in For(self.sim,
+                                                     start,
+                                                     start + nshapes[neigh_cell][shape]):
+
+                                particle_id = cell_particles[neigh_cell][cell_particle]
+
+                                if self.sim._compute_half:
+                                    shape_id = self.sim.get_shape_id(shape)
+                                    shape_cond = particle_shape[particle_id] > shape_id
+                                    condition = ScalarOp.or_op(shape_cond, self.particle < particle_id)
+
+                                else:
+                                    condition = ScalarOp.neq(particle_id, self.particle)
 
-                            for _ in Filter(self.sim, condition):
-                                yield Neighbor(self.sim, cell_particle, neigh_cell, particle_id, shape)
+                                for _ in Filter(self.sim, condition):
+                                    yield Neighbor(
+                                        self.sim, cell_particle, neigh_cell, particle_id, shape)
 
         else:
             neighborlists = self.neighbor_lists.neighborlists
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 5bb602ccacfcb0877d922ab8e0cb06851886e44e..8b10fe053220f2d7240e8b6898a2750cd16606d6 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -12,7 +12,7 @@ from pairs.ir.variables import Variables
 #from pairs.graph.graphviz import ASTGraph
 from pairs.mapping.funcs import compute, setup
 from pairs.sim.arrays import DeclareArrays
-from pairs.sim.cell_lists import CellLists, BuildCellLists, BuildCellListsStencil, PartitionCellLists
+from pairs.sim.cell_lists import CellLists, BuildCellLists, BuildCellListsStencil, PartitionCellLists, BuildCellNeighborLists
 from pairs.sim.comm import Comm
 from pairs.sim.contact_history import ContactHistory, BuildContactHistory, ClearUnusedContactHistory, ResetContactHistoryUsageStatus
 from pairs.sim.copper_fcc_lattice import CopperFCCLattice
@@ -64,6 +64,7 @@ class Simulation:
         self.particle_flags = self.add_property('flags', Types.Int32, 0)
         self.grid = None
         self.cell_lists = None
+        self._store_neighbors_per_cell = False
         self.neighbor_lists = None
         self.scope = []
         self.nested_count = 0
@@ -240,11 +241,15 @@ class Simulation:
     def copper_fcc_lattice(self, nx, ny, nz, rho, temperature, ntypes):
         self.setups.add_statement(CopperFCCLattice(self, nx, ny, nz, rho, temperature, ntypes))
 
-    def build_cell_lists(self, spacing):
+    def build_cell_lists(self, spacing, store_neighbors_per_cell=False):
+        self._store_neighbors_per_cell = store_neighbors_per_cell
         self.cell_lists = CellLists(self, self._dom_part, spacing, spacing)
         return self.cell_lists
 
     def build_neighbor_lists(self, spacing):
+        assert self._store_neighbors_per_cell is False, \
+            "Using neighbor-lists with store_neighbors_per_cell option is invalid."
+
         self.cell_lists = CellLists(self, self._dom_part, spacing, spacing)
         self.neighbor_lists = NeighborLists(self, self.cell_lists)
         return self.neighbor_lists
@@ -380,6 +385,10 @@ class Simulation:
             (PartitionCellLists(self, self.cell_lists), every_reneighbor_params)
         ]
 
+        if self._store_neighbors_per_cell:
+            timestep_procedures.append(
+                (BuildCellNeighborLists(self, self.cell_lists), every_reneighbor_params))
+
         if self.neighbor_lists is not None:
             timestep_procedures.append(
                 (BuildNeighborLists(self, self.neighbor_lists), every_reneighbor_params))