diff --git a/examples/dem.py b/examples/dem.py
index f2a60d5b71728172186c87c9fc85cefa4784d002..8ac9d8266a6a3625d025755b92e7d6f658609902 100644
--- a/examples/dem.py
+++ b/examples/dem.py
@@ -107,7 +107,7 @@ psim.add_contact_property('is_sticking', pairs.int32(), 0)
 psim.add_contact_property('tangential_spring_displacement', pairs.vector(), [0.0, 0.0, 0.0])
 psim.add_contact_property('impact_velocity_magnitude', pairs.double(), 0.0)
 
-psim.read_particle_data("data/fluidized_bed.input", ['mass', 'position', 'linear_velocity'])
+psim.read_particle_data("data/fluidized_bed.input", ['mass', 'position', 'linear_velocity'], pairs.sphere())
 psim.build_neighbor_lists(cutoff_radius + skin)
 psim.vtk_output(f"output/test_{target}")
 psim.compute(linear_spring_dashpot, cutoff_radius, symbols={'dt': dt})
diff --git a/examples/lj.py b/examples/lj.py
index dab50c7457308f902899efbd25b257eaa341186f..a9aa2ffe80844ad4b07233e776421211d58cd541 100644
--- a/examples/lj.py
+++ b/examples/lj.py
@@ -34,7 +34,7 @@ psim.add_property('force', pairs.vector(), vol=True)
 psim.add_feature('type', ntypes)
 psim.add_feature_property('type', 'epsilon', pairs.double(), [sigma for i in range(ntypes * ntypes)])
 psim.add_feature_property('type', 'sigma6', pairs.double(), [epsilon for i in range(ntypes * ntypes)])
-psim.read_particle_data("data/minimd_setup_32x32x32.input", ['type', 'mass', 'position', 'linear_velocity', 'particle_flags'])
+psim.read_particle_data("data/minimd_setup_32x32x32.input", ['type', 'mass', 'position', 'linear_velocity', 'flags'], pairs.sphere())
 psim.build_neighbor_lists(cutoff_radius + skin)
 psim.vtk_output(f"output/test_{target}")
 psim.compute(lj, cutoff_radius)
diff --git a/runtime/read_from_file.hpp b/runtime/read_from_file.hpp
index 4f98c9390f69d0515e81a93de5481092d37f921d..95f007ecec239fa6979a06937ea6b27d7924f9ab 100644
--- a/runtime/read_from_file.hpp
+++ b/runtime/read_from_file.hpp
@@ -30,9 +30,10 @@ void read_grid_data(PairsSimulation *ps, const char *filename, double *grid_buff
     }
 }
 
-size_t read_particle_data(PairsSimulation *ps, const char *filename, const property_t properties[], size_t nprops) {
+size_t read_particle_data(PairsSimulation *ps, const char *filename, const property_t properties[], size_t nprops, int shape_id) {
     std::ifstream in_file(filename, std::ifstream::in);
     std::string line;
+    auto shape_ptr = ps->getAsIntegerProperty(ps->getPropertyByName("shape"));
     size_t n = 0;
 
     if(in_file.is_open()) {
@@ -77,7 +78,9 @@ size_t read_particle_data(PairsSimulation *ps, const char *filename, const prope
                 i++;
             }
 
-            n += (within_domain) ? 1 : 0;
+            if(within_domain) {
+                shape_ptr(n++) = shape_id;
+            }
         }
 
         in_file.close();
diff --git a/src/pairs/__init__.py b/src/pairs/__init__.py
index 2807d54642f7f47fc2be74edab2b30b5a5320089..abaeba14a10b7ba59d5c4e0a9055d4edc2c20c62 100644
--- a/src/pairs/__init__.py
+++ b/src/pairs/__init__.py
@@ -1,6 +1,7 @@
 from pairs.ir.types import Types
 from pairs.code_gen.cgen import CGen
 from pairs.code_gen.target import Target
+from pairs.sim.shapes import Shapes
 from pairs.sim.simulation import Simulation
 
 
@@ -21,3 +22,9 @@ def double():
 
 def vector():
     return Types.Vector
+
+def sphere():
+    return Shapes.Sphere
+
+def halfspace():
+    return Shapes.Halfspace
diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py
index 8ddf259f4b856318062732d6687784e03991bb7f..9bb616b0ab38a450c24b40b4d3a266ab6fc83ff4 100644
--- a/src/pairs/sim/cell_lists.py
+++ b/src/pairs/sim/cell_lists.py
@@ -3,12 +3,12 @@ import math
 from pairs.ir.assign import Assign
 from pairs.ir.ast_term import ASTTerm
 from pairs.ir.atomic import AtomicAdd
-from pairs.ir.scalars import ScalarOp
 from pairs.ir.block import pairs_device_block, pairs_host_block
 from pairs.ir.branches import Branch, Filter
 from pairs.ir.cast import Cast
-from pairs.ir.loops import For, ParticleFor
+from pairs.ir.loops import For, ParticleFor, While
 from pairs.ir.math import Ceil
+from pairs.ir.scalars import ScalarOp
 from pairs.ir.types import Types
 from pairs.ir.utils import Print
 from pairs.sim.flags import Flags
@@ -31,6 +31,7 @@ class CellLists:
         self.dim_ncells         =   self.sim.add_static_array('dim_cells', self.sim.ndims(), Types.Int32)
         self.cell_particles     =   self.sim.add_array('cell_particles', [self.ncells_capacity, self.cell_capacity], Types.Int32)
         self.cell_sizes         =   self.sim.add_array('cell_sizes', self.ncells_capacity, Types.Int32)
+        self.nshapes            =   self.sim.add_array('nshapes', [self.ncells_capacity, self.sim.max_shapes()], Types.Int32)
         self.stencil            =   self.sim.add_array('stencil', self.nstencil_max, Types.Int32)
         self.particle_cell      =   self.sim.add_array('particle_cell', self.sim.particle_capacity, Types.Int32)
 
@@ -104,3 +105,35 @@ class BuildCellLists(Lowerable):
                 index_in_cell = AtomicAdd(sim, cl.cell_sizes[flat_index], 1)
                 Assign(sim, cl.particle_cell[i], flat_index)
                 Assign(sim, cl.cell_particles[flat_index][index_in_cell], i)
+
+
+class PartitionCellLists(Lowerable):
+    def __init__(self, sim, cell_lists):
+        super().__init__(sim)
+        self.cell_lists = cell_lists
+
+    @pairs_device_block
+    def lower(self):
+        self.sim.module_name("partition_cell_lists")
+        cell_particles = self.cell_lists.cell_particles
+
+        for cell in For(self.sim, 0, self.cell_lists.ncells):
+            start = self.sim.add_temp_var(0)
+            end = self.sim.add_temp_var(0)
+
+            for shape in For(self.sim, 0, self.sim.max_shapes()):
+                Assign(self.sim, end, self.cell_lists.cell_sizes[cell] - 1)
+
+                for _ in While(self.sim, start < end):
+                    particle = cell_particles[cell][start]
+
+                    for unmatch in Branch(self.sim, ScalarOp.neq(self.sim.particle_shape[particle], shape)):
+                        if unmatch:
+                            Assign(self.sim, cell_particles[cell][start], cell_particles[cell][end])
+                            Assign(self.sim, cell_particles[cell][end], particle)
+                            Assign(self.sim, end, end - 1)
+
+                        else:
+                            Assign(self.sim, start, start + 1)
+
+                    Assign(self.sim, self.cell_lists.nshapes[cell][shape], start)
diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index 55b3d6c0514eed7661d7d3356b7aa2cc7daf328e..bb4b2612009d12276cf1f938add21fc68c544a4b 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -44,7 +44,7 @@ class Comm:
 
     @pairs_inline
     def borders(self):
-        prop_list = [self.sim.property(p) for p in ['mass', 'position', 'particle_flags']]
+        prop_list = [self.sim.property(p) for p in ['mass', 'position', 'flags']]
         Assign(self.sim, self.nsend_all, 0)
         Assign(self.sim, self.sim.nghost, 0)
 
@@ -59,7 +59,7 @@ class Comm:
 
     @pairs_inline
     def exchange(self):
-        prop_list = [self.sim.property(p) for p in ['mass', 'position', 'linear_velocity', 'particle_flags']]
+        prop_list = [self.sim.property(p) for p in ['mass', 'position', 'linear_velocity', 'shape', 'flags']]
         for step in range(self.dom_part.number_of_steps()):
             Assign(self.sim, self.nsend_all, 0)
             Assign(self.sim, self.sim.nghost, 0)
diff --git a/src/pairs/sim/read_from_file.py b/src/pairs/sim/read_from_file.py
index ec2497b27c7746bface233c7610a24cf9dc5e99d..c87171a5e3d7621de26a1b39dffe7aad89620230 100644
--- a/src/pairs/sim/read_from_file.py
+++ b/src/pairs/sim/read_from_file.py
@@ -8,12 +8,13 @@ from pairs.sim.lowerable import Lowerable
 
 
 class ReadParticleData(Lowerable):
-    def __init__(self, sim, filename, items):
+    def __init__(self, sim, filename, items, shape_id):
         super().__init__(sim)
         self.filename = filename
         self.attrs = ParticleAttributeList(sim, items)
         self.grid = MutableGrid(sim, sim.ndims())
-        self.grid_buffer = self.sim.add_static_array("grid_buffer", [self.sim.ndims() * 2], Types.Double)
+        self.grid_buffer = sim.add_static_array("grid_buffer", [sim.ndims() * 2], Types.Double)
+        self.shape_id = shape_id
 
     @pairs_inline
     def lower(self):
@@ -26,7 +27,7 @@ class ReadParticleData(Lowerable):
         grid_array = [[self.grid.min(d), self.grid.max(d)] for d in range(self.sim.ndims())]
         Call_Void(self.sim, "pairs->initDomain", [param for delim in grid_array for param in delim]),
         Call_Void(self.sim, "pairs->fillCommunicationArrays", [dom_part.neighbor_ranks, dom_part.pbc, dom_part.subdom])
-        Assign(self.sim, self.sim.nlocal, Call_Int(self.sim, "pairs::read_particle_data", [self.filename, self.attrs, self.attrs.length()]))
+        Assign(self.sim, self.sim.nlocal, Call_Int(self.sim, "pairs::read_particle_data", [self.filename, self.attrs, self.attrs.length(), self.shape_id]))
 
 
 class ReadFeatureData(Lowerable):
diff --git a/src/pairs/sim/shapes.py b/src/pairs/sim/shapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..8359da7f2b6a1f7fd138d100a3eb77035d70e4ed
--- /dev/null
+++ b/src/pairs/sim/shapes.py
@@ -0,0 +1,3 @@
+class Shapes:
+    Sphere      =   0
+    Halfspace   =   1
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 7b70de08e69896f9686d785c7405aca078c6f8b9..56e30c330b2c9009589a2c02309eabb1fee3954f 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -13,7 +13,7 @@ from pairs.ir.variables import Variables
 from pairs.graph.graphviz import ASTGraph
 from pairs.mapping.funcs import compute
 from pairs.sim.arrays import DeclareArrays
-from pairs.sim.cell_lists import CellLists, BuildCellLists, BuildCellListsStencil
+from pairs.sim.cell_lists import CellLists, BuildCellLists, BuildCellListsStencil, PartitionCellLists
 from pairs.sim.comm import Comm
 from pairs.sim.contact_history import ContactHistory, BuildContactHistory
 from pairs.sim.domain_partitioning import DimensionRanges
@@ -46,7 +46,8 @@ class Simulation:
         self.nlocal = self.add_var('nlocal', Types.Int32)
         self.nghost = self.add_var('nghost', Types.Int32)
         self.resizes = self.add_array('resizes', 3, Types.Int32, arr_sync=False)
-        self.particle_flags = self.add_property('particle_flags', Types.Int32, 0)
+        self.particle_shape = self.add_property('shape', Types.Int32, 0)
+        self.particle_flags = self.add_property('flags', Types.Int32, 0)
         self.grid = None
         self.cell_lists = None
         self.neighbor_lists = None
@@ -70,6 +71,9 @@ class Simulation:
         self._target = None
         self._dom_part = DimensionRanges(self)
 
+    def max_shapes(self):
+        return 2
+
     def add_module(self, module):
         assert isinstance(module, Module), "add_module(): Given parameter is not of type Module!"
         if module.name not in [m.name for m in self.module_list]:
@@ -176,9 +180,9 @@ class Simulation:
         lattice = ParticleLattice(self, grid, spacing, props, positions)
         self.setups.add_statement(lattice)
 
-    def read_particle_data(self, filename, prop_names):
+    def read_particle_data(self, filename, prop_names, shape_id):
         props = [self.property(prop_name) for prop_name in prop_names]
-        read_object = ReadParticleData(self, filename, props)
+        read_object = ReadParticleData(self, filename, props, shape_id)
         self.setups.add_statement(read_object)
         self.grid = read_object.grid
 
@@ -271,6 +275,7 @@ class Simulation:
             (comm.exchange(), 20),
             (comm.borders(), comm.synchronize(), 20),
             (BuildCellLists(self, self.cell_lists), 20),
+            (PartitionCellLists(self, self.cell_lists), 20),
             (BuildNeighborLists(self, self.neighbor_lists), 20),
         ]