diff --git a/examples/benchmarks/sd_static.cpp b/examples/benchmarks/sd_static.cpp
index 9da037e7f498115855764aede894f43d69f58486..aad480137a1ece7601baf1e170fdcbe4a7bd0188 100644
--- a/examples/benchmarks/sd_static.cpp
+++ b/examples/benchmarks/sd_static.cpp
@@ -34,14 +34,17 @@ int main(int argc, char **argv) {
                                         diameter, diameter, diameter,
                                         initial_velocity, density, 1);
     
+    pairs_sim->update_mass_and_inertia(); 
     
     // Cell width is here smaller than sphere diameter only for convenience to have everything aligned on a grid, but this 
     // doesn't affect the interactions computed. All spheres are on cell centers and are in contact with 6 neighbors. 
     double cell_width = particle_spacing;
-    pairs_sim->setup_cells(cell_width, cell_width, cell_width, cell_width);
+
+    pairs_sim->setCellWidth(cell_width, cell_width, cell_width);
+    pairs_sim->setInteractionRadius(cell_width);
+    pairs_sim->updateDomain();
     
     // Inertia update is required for euler updates to be valid (but particles remain stationary)
-    pairs_sim->update_mass_and_inertia(); 
     double dt = 0.001;  // Arbitrary
     
     int rank = pairs_sim->rank();
@@ -59,10 +62,9 @@ int main(int argc, char **argv) {
 
     for (int t=0; t<num_timesteps; ++t){
         if ((t%print_interval==0) && rank==0) std::cout << "Timestep: " << t << std::endl;
-        pairs_sim->communicate(t);
-        pairs_sim->update_cells(t);
         pairs_sim->spring_dashpot();
         pairs_sim->euler(dt);
+        pairs_sim->reneighbor();
     }
 
     auto end = std::chrono::high_resolution_clock::now();
@@ -81,6 +83,8 @@ int main(int argc, char **argv) {
         double pups = global_nparticles * num_timesteps / total_runtime;    // particle updates per second
         std::cout << "PUPS: " << pups << std::endl;
     }
+
+    pairs::log_timers(pairs_runtime);
     
     // pairs::vtk_write_data(pairs_runtime, "output/local_spheres", 0, pairs_sim->nlocal(), 0);
     // pairs::vtk_write_data(pairs_runtime, "output/ghost_spheres", pairs_sim->nlocal(), pairs_sim->size(), 0);
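Note: the single setup_cells(w, w, w, cutoff) call is replaced above by three explicit steps, and the per-timestep communicate(t)/update_cells(t) pair is folded into reneighbor(). A minimal sketch of the resulting order, using only the PairsSimulation calls that appear in this diff (illustrative, not a complete program):

    // Setup (once, after particle creation):
    pairs_sim->update_mass_and_inertia();
    pairs_sim->setCellWidth(cell_width, cell_width, cell_width);
    pairs_sim->setInteractionRadius(cell_width);   // cutoff may differ from the cell width
    pairs_sim->updateDomain();                     // replaces setup_cells()

    // Per timestep:
    pairs_sim->spring_dashpot();
    pairs_sim->euler(dt);
    pairs_sim->reneighbor();                       // replaces communicate(t) + update_cells(t)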
diff --git a/examples/benchmarks/sd_static_triangular.cpp b/examples/benchmarks/sd_static_triangular.cpp
index 05323a2aca0a7c2c5eabd2bd5e2f121dd908911c..284b8aa8c6777ad0cd64b0c2fd7658ad37077284 100644
--- a/examples/benchmarks/sd_static_triangular.cpp
+++ b/examples/benchmarks/sd_static_triangular.cpp
@@ -15,7 +15,7 @@ void print_global_stats(std::string name, Type value, MPI_Datatype mpi_type, MPI
 
     if(rank == 0){
         std::sort(all_values.begin(), all_values.end());
-        Type sum = std::accumulate(all_values.begin(), all_values.end(), 0.0);
+        Type sum = std::accumulate(all_values.begin(), all_values.end(), Type());
         double avg = static_cast<double>(sum)/world_size;
 
         // Standard deviation ------------------
@@ -28,9 +28,9 @@ void print_global_stats(std::string name, Type value, MPI_Datatype mpi_type, MPI
         // Median ------------------------------
         double median = 0.0;
         if(world_size%2 == 0){
-            median = (all_values[world_size/2 - 1] + all_values[world_size/2]) / 2.0;   
+            median = static_cast<double>(all_values[world_size/2 - 1] + all_values[world_size/2]) / 2.0;   
         } else{
-            median = all_values[world_size/2];
+            median = static_cast<double>(all_values[world_size/2]);
         }
 
         std::cout << "-----------------------------------" << std::endl;
@@ -90,37 +90,47 @@ int main(int argc, char **argv) {
     
     double initial_velocity = 0.0;  // Stationary 
     double density = 1000;          // Arbitrary
-    double lower_tirangular = true;
+    bool lower_tirangular = true;
     
     pairs::dem_sc_grid(pairs_runtime,   domain_size[0], domain_size[1], domain_size[2],
                                         particle_spacing, 
                                         diameter, diameter, diameter,
                                         initial_velocity, density, 1, lower_tirangular);
     
-    double cell_width = diameter;
-    pairs_sim->setup_cells(cell_width, cell_width, cell_width, cell_width);
-
     // Inertia update is required for euler updates to be valid (but particles remain stationary)
     pairs_sim->update_mass_and_inertia(); 
+
+    double cell_width = diameter;
+    pairs_sim->setCellWidth(cell_width, cell_width, cell_width);
+    pairs_sim->setInteractionRadius(cell_width);
+
     double dt = 0.001;  // Arbitrary
-    int print_interval = (num_timesteps >= 5) ? (num_timesteps / 5) : 1;
+    uint64_t print_interval = (num_timesteps >= 5) ? (num_timesteps / 5) : 1;
     
-    // Update domain so stats become available (does rebalancing if rebalance is ture)
-    pairs_sim->update_domain();
+    // Update domain so stats become available (rebalances if rebalancing is enabled)
+    pairs_sim->updateDomain();
 
     // Stats
     // ------------------------------------------------------------------------------
     int rank = pairs_sim->rank();
     int world_size = pairs_runtime->getDomainPartitioner()->getWorldSize();
 
-    int num_neigh_ranks = pairs_runtime->getDomainPartitioner()->getNumberOfNeighborRanks();
+    int num_local_aabbs = pairs_runtime->getDomainPartitioner()->getNumberOfLocalAABBs();
     int num_neigh_aabbs = pairs_runtime->getDomainPartitioner()->getNumberOfNeighborAABBs();
+    int num_neigh_ranks = pairs_runtime->getDomainPartitioner()->getNumberOfNeighborRanks();
     uint64_t nlocal = pairs_sim->nlocal();
     uint64_t nghost = pairs_sim->nghost();
-    print_global_stats("NUM_NEIGH_RANKS", num_neigh_ranks, MPI_INT, MPI_COMM_WORLD);
-    print_global_stats("NUM_NEIGH_AABBS", num_neigh_aabbs, MPI_INT, MPI_COMM_WORLD);
+
+    std::cout << "rank (" << rank << "): \t nlocal = " << nlocal << " nghost = " << nghost << 
+         " local_aabbs = " << num_local_aabbs << 
+         " neigh_aabbs = " << num_neigh_aabbs << 
+         " neigh_ranks = " << num_neigh_ranks << std::endl;
+
     print_global_stats("NLOCAL", nlocal, MPI_UINT64_T, MPI_COMM_WORLD);
     print_global_stats("NGHOST", nghost, MPI_UINT64_T, MPI_COMM_WORLD);
+    print_global_stats("NUM_LOCAL_AABBS", num_local_aabbs, MPI_INT, MPI_COMM_WORLD);
+    print_global_stats("NUM_NEIGH_AABBS", num_neigh_aabbs, MPI_INT, MPI_COMM_WORLD);
+    print_global_stats("NUM_NEIGH_RANKS", num_neigh_ranks, MPI_INT, MPI_COMM_WORLD);
 
     if(rank==0){
         std::cout << "NUM_PROC: " << world_size << std::endl;
@@ -133,12 +143,11 @@ int main(int argc, char **argv) {
     MPI_Barrier(MPI_COMM_WORLD);
     auto start = std::chrono::high_resolution_clock::now();
 
-    for (int t=0; t<num_timesteps; ++t){
+    for (uint64_t t=0; t<num_timesteps; ++t){
         if ((t%print_interval==0) && rank==0) std::cout << "Timestep: " << t << std::endl;
-        pairs_sim->communicate(t);
-        pairs_sim->update_cells(t);
         pairs_sim->spring_dashpot();
         pairs_sim->euler(dt);
+        pairs_sim->reneighbor();
     }
 
     auto end = std::chrono::high_resolution_clock::now();
@@ -152,13 +161,14 @@ int main(int argc, char **argv) {
         std::cout << "TOTAL_RUNTIME: " << total_runtime << std::endl;
         std::cout << "GLOBAL_NPARTICLES: " << global_nparticles << std::endl;
         
-        double pups = global_nparticles * num_timesteps / total_runtime;    // particle updates per second
+        double pups = static_cast<double>(global_nparticles * num_timesteps) / total_runtime;    // particle updates per second
         std::cout << "PUPS: " << pups << std::endl;
     }
     
     // pairs::vtk_write_subdom(pairs_runtime, "output/sd_subdom", 0);
     // pairs::vtk_write_data(pairs_runtime, "output/sd_local", 0, pairs_sim->nlocal(), 0);
     // pairs::vtk_write_data(pairs_runtime, "output/sd_ghost", pairs_sim->nlocal(), pairs_sim->size(), 0);
+    pairs::log_timers(pairs_runtime);
 
     pairs_sim->end();
 }
\ No newline at end of file
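Note: the two fixes in print_global_stats matter when Type is an unsigned 64-bit counter. Seeding std::accumulate with 0.0 makes the accumulator a double (the type of the init argument), which can lose precision for large sums, whereas Type() keeps the sum in Type; the added static_casts make the integer-to-double conversion for the median explicit. A self-contained, single-process sketch of the same computation (hypothetical names, no MPI):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <numeric>
    #include <string>
    #include <vector>

    // Simplified version of the per-rank statistics computed above.
    template<typename Type>
    void print_stats(const std::string &name, std::vector<Type> values) {
        std::sort(values.begin(), values.end());

        // Value-initialized accumulator: the sum stays in Type instead of double.
        Type sum = std::accumulate(values.begin(), values.end(), Type());
        double avg = static_cast<double>(sum) / values.size();

        double median = (values.size() % 2 == 0)
            ? static_cast<double>(values[values.size()/2 - 1] + values[values.size()/2]) / 2.0
            : static_cast<double>(values[values.size()/2]);

        std::cout << name << ": avg = " << avg << " median = " << median << std::endl;
    }

    int main() {
        print_stats<uint64_t>("NLOCAL", {100, 250, 300, 425});
    }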
diff --git a/examples/benchmarks/spring_dashpot.py b/examples/benchmarks/spring_dashpot.py
index bfc7aed159450cd8c3f91b47c94dbe2541aaf501..b1380dcce58957dc04c0a2407ed8e1369ea8d619 100644
--- a/examples/benchmarks/spring_dashpot.py
+++ b/examples/benchmarks/spring_dashpot.py
@@ -86,7 +86,7 @@ psim.add_feature_property('type', 'friction', pairs.real())
 # psim.set_domain_partitioner(pairs.regular_domain_partitioner())
 psim.set_domain_partitioner(pairs.block_forest())
 psim.pbc([True, True, True])
-psim.build_cell_lists(use_halo_cells=False)
+psim.build_cell_lists(use_halo_cells=False, optimize_halo_paddings=False)
 
 psim.compute(update_mass_and_inertia, symbols={'infinity': math.inf })
 psim.compute(spring_dashpot, profile=False)
diff --git a/examples/benchmarks/spring_dashpot_no_pbc.py b/examples/benchmarks/spring_dashpot_no_pbc.py
index d0d43f6417b74f48bd2a85ca784f7d875f2f2a17..ed613d33ca594db756023743664e3fd7ee43db41 100644
--- a/examples/benchmarks/spring_dashpot_no_pbc.py
+++ b/examples/benchmarks/spring_dashpot_no_pbc.py
@@ -86,7 +86,7 @@ psim.add_feature_property('type', 'friction', pairs.real())
 # psim.set_domain_partitioner(pairs.regular_domain_partitioner())
 psim.set_domain_partitioner(pairs.block_forest())
 psim.pbc([False, False, False])
-psim.build_cell_lists(use_halo_cells=False)
+psim.build_cell_lists(use_halo_cells=False, optimize_halo_paddings=False)
 
 psim.compute(update_mass_and_inertia, symbols={'infinity': math.inf })
 psim.compute(spring_dashpot, profile=False)
diff --git a/examples/modular/force_reduction.cpp b/examples/modular/force_reduction.cpp
index 7e770a58704c2ff928a395f9111094557406af43..b03e50deb45e6ce55d44e4f19ee41f5d19345cd6 100644
--- a/examples/modular/force_reduction.cpp
+++ b/examples/modular/force_reduction.cpp
@@ -15,9 +15,11 @@ int main(int argc, char **argv) {
     // Create bodies
     pairs::id_t pUid = pairs::create_sphere(pairs_runtime, 0.0499,   0.0499,   0.07,   0.5, 0.5, 0 ,   1000, 0.0045, 0, 0);
     
-    // setup_cells after creating all bodies
-    pairs_sim->setup_cells();
     pairs_sim->update_mass_and_inertia();
+    
+    // Call updateDomain after creating all bodies
+    pairs_sim->updateDomain();
+    ac->update();
 
     // Track particle
     //-------------------------------------------------------------------------------------------
@@ -30,11 +32,6 @@ int main(int argc, char **argv) {
     if (pUid != ac->getInvalidUid()){
         std::cout<< "Particle " << pUid << " will be tracked by rank " << pairs_sim->rank() << std::endl;
     }
-
-    // Communicate particles (exchange/ghost)
-    //-------------------------------------------------------------------------------------------
-    pairs_sim->communicate(0);
-    ac->update();
         
     // Helper lambdas for demo
     //-------------------------------------------------------------------------------------------
@@ -91,16 +88,14 @@ int main(int argc, char **argv) {
         
         // Do computations
         //-------------------------------------------------------------------------------------------
-        pairs_sim->update_cells(t); 
         pairs_sim->gravity(); 
         pairs_sim->spring_dashpot();
         pairs_sim->euler(5e-5);        
         //-------------------------------------------------------------------------------------------
 
-        std::cout << "---- reverse_comm and reduce ----" << std::endl;
-        // reverse_comm() communicates data from ghost particles back to their owner ranks using
-        // information from the previous time that communicate() was called 
-        pairs_sim->reverse_comm();  
+        std::cout << "Reverse communicate and reduce." << std::endl;
+        // Communicate ghost particle data back to their owner ranks and reduce
+        pairs_sim->reverseCommunicate();  
 
         // Get the reduced force on the owner rank
         //-------------------------------------------------------------------------------------------
@@ -115,9 +110,9 @@ int main(int argc, char **argv) {
                         << force_sum[0] << ", " << force_sum[1] << ", " << force_sum[2] << ")" <<  std::endl;
         }
         
-        // Usual communication 
+        // Forward communication 
         //-------------------------------------------------------------------------------------------
-        pairs_sim->communicate(t);
+        pairs_sim->reneighbor();
         ac->update();
     }
 
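Note: reverse_comm() is renamed to reverseCommunicate(), and the forward exchange at the end of the step is now reneighbor(). A condensed sketch of the per-timestep pattern this example follows (names taken from the example above; ac is the PairsAccessor):

    pairs_sim->gravity();
    pairs_sim->spring_dashpot();
    pairs_sim->euler(5e-5);

    // Push ghost contributions (here: forces) back to the owning rank and reduce them there.
    pairs_sim->reverseCommunicate();

    // ... read the reduced force on the owner rank through the accessor ...

    // Forward path: exchange/ghost particles again, then refresh the accessor.
    pairs_sim->reneighbor();
    ac->update();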
diff --git a/examples/modular/sd_1.cpp b/examples/modular/sd_1.cpp
index d7e0dcae798ee4b8e46013bd9e872ad97ea76eb4..a705c2538af87405a97abbcb79f9f3b5d48b920d 100644
--- a/examples/modular/sd_1.cpp
+++ b/examples/modular/sd_1.cpp
@@ -21,9 +21,12 @@ int main(int argc, char **argv) {
     pairs::create_sphere(pairs_runtime, 0.6, 0.6, 0.7,      -2, -2, 0,  1000, 0.05, 0, 0);
     pairs::create_sphere(pairs_runtime, 0.4, 0.4, 0.68,    2, 2, 0,    1000, 0.05, 0, 0);
 
-    pairs_sim->setup_cells(0.1, 0.1, 0.1, 0.1);
     pairs_sim->update_mass_and_inertia();
 
+    pairs_sim->setCellWidth(0.1, 0.1, 0.1);
+    pairs_sim->setInteractionRadius(0.1);
+    pairs_sim->updateDomain();
+
     int num_timesteps = 2000;
     int vtk_freq = 20;
     double dt = 1e-3;
@@ -31,13 +34,10 @@ int main(int argc, char **argv) {
     for (int t=0; t<num_timesteps; ++t){
         if ((t%500==0) && pairs_sim->rank()==0) std::cout << "Timestep: " << t << std::endl;
 
-        pairs_sim->communicate(t);
-        
-        pairs_sim->update_cells(t); 
-
         pairs_sim->gravity(); 
         pairs_sim->spring_dashpot(); 
         pairs_sim->euler(dt); 
+        pairs_sim->reneighbor();
 
         pairs::vtk_write_data(pairs_runtime, "output/sd_1_local", 0, pairs_sim->nlocal(), t, vtk_freq);
         pairs::vtk_write_data(pairs_runtime, "output/sd_1_ghost", pairs_sim->nlocal(), pairs_sim->size(), t, vtk_freq);
diff --git a/examples/modular/sd_2.cpp b/examples/modular/sd_2.cpp
index 3ae126949f934aa2d1eff1eb3d0d895a26afec9d..cc5eb453c92e5727cc716a5628c29d4787c1a170 100644
--- a/examples/modular/sd_2.cpp
+++ b/examples/modular/sd_2.cpp
@@ -42,9 +42,12 @@ int main(int argc, char **argv) {
     pairs::create_sphere(pairs_runtime, 0.6, 0.6, 0.7,      -2, -2, 0,  1000, 0.05, 0, 0);
     pairs::create_sphere(pairs_runtime, 0.4, 0.4, 0.68,    2, 2, 0,    1000, 0.05, 0, 0);
 
-    pairs_sim->setup_cells(0.1, 0.1, 0.1, 0.1);
     pairs_sim->update_mass_and_inertia();
 
+    pairs_sim->setCellWidth(0.1, 0.1, 0.1);
+    pairs_sim->setInteractionRadius(0.1);
+    pairs_sim->updateDomain();
+
     int num_timesteps = 2000;
     int vtk_freq = 20;
     double dt = 1e-3;
@@ -52,13 +55,10 @@ int main(int argc, char **argv) {
     for (int t=0; t<num_timesteps; ++t){
         if ((t%500==0) && pairs_sim->rank()==0) std::cout << "Timestep: " << t << std::endl;
 
-        pairs_sim->communicate(t);
-        
-        pairs_sim->update_cells(t); 
-
         pairs_sim->gravity(); 
         pairs_sim->spring_dashpot(); 
         pairs_sim->euler(dt); 
+        pairs_sim->reneighbor();
 
         pairs::vtk_write_data(pairs_runtime, "output/sd_2_local", 0, pairs_sim->nlocal(), t, vtk_freq);
         pairs::vtk_write_data(pairs_runtime, "output/sd_2_ghost", pairs_sim->nlocal(), pairs_sim->size(), t, vtk_freq);
diff --git a/examples/modular/sd_3_CPU.cpp b/examples/modular/sd_3_CPU.cpp
index a8e4a93bc5e816eb8eee900b7ef1d6769e7ec206..44766dd1c6950723fc4d96801cf964dd2407a957 100644
--- a/examples/modular/sd_3_CPU.cpp
+++ b/examples/modular/sd_3_CPU.cpp
@@ -31,11 +31,13 @@ int main(int argc, char **argv) {
     MPI_Allreduce(MPI_IN_PLACE, &pUid, 1, MPI_LONG_LONG_INT, MPI_SUM, MPI_COMM_WORLD);
 
     auto pIsLocalInMyRank = [&](pairs::id_t uid){return ac->uidToIdxLocal(uid) != ac->getInvalidIdx();};
-
-    pairs_sim->setup_sim(0.1, 0.1, 0.1, 0.1);
+    
     pairs_sim->update_mass_and_inertia();
 
-    pairs_sim->communicate(0);
+    pairs_sim->setCellWidth(0.1, 0.1, 0.1);
+    pairs_sim->setInteractionRadius(0.1);
+    pairs_sim->updateDomain();
+    
 
     int num_timesteps = 2000;
     int vtk_freq = 20;
@@ -57,7 +59,6 @@ int main(int argc, char **argv) {
 
         // Calculate forces
         //-------------------------------------------------------------------------------------------
-        pairs_sim->update_cells(t);
         pairs_sim->gravity(); 
         pairs_sim->spring_dashpot(); 
 
@@ -85,7 +86,7 @@ int main(int argc, char **argv) {
 
         // Communicate
         //-------------------------------------------------------------------------------------------
-        pairs_sim->communicate(t);
+        pairs_sim->reneighbor();
 
         pairs::vtk_write_data(pairs_runtime, "output/sd_3_CPU_local", 0, ac->nlocal(), t, vtk_freq);
         pairs::vtk_write_data(pairs_runtime, "output/sd_3_CPU_ghost", ac->nlocal(), ac->size(), t, vtk_freq);
diff --git a/examples/modular/sd_3_GPU.cu b/examples/modular/sd_3_GPU.cu
index bf95b2037f366cb417320f22a34c24591a1f7ebc..5b1789618c0877663cc077b4ff966604677158f6 100644
--- a/examples/modular/sd_3_GPU.cu
+++ b/examples/modular/sd_3_GPU.cu
@@ -64,11 +64,13 @@ int main(int argc, char **argv) {
     MPI_Allreduce(MPI_IN_PLACE, &pUid, 1, MPI_LONG_LONG_INT, MPI_SUM, MPI_COMM_WORLD);
 
     auto pIsLocalInMyRank = [&](pairs::id_t uid){return ac->uidToIdxLocal(uid) != ac->getInvalidIdx();};
-
-    pairs_sim->setup_cells(0.1, 0.1, 0.1, 0.1);
+    
     pairs_sim->update_mass_and_inertia();
 
-    pairs_sim->communicate(0);
+    pairs_sim->setCellWidth(0.1, 0.1, 0.1);
+    pairs_sim->setInteractionRadius(0.1);
+    pairs_sim->updateDomain();
+
     // PairsAccessor requires an update when particles are communicated 
     ac->update();
 
@@ -104,7 +106,6 @@ int main(int argc, char **argv) {
 
         // Calculate forces
         //-------------------------------------------------------------------------------------------
-        pairs_sim->update_cells(t);
         pairs_sim->gravity(); 
         pairs_sim->spring_dashpot(); 
 
@@ -138,10 +139,10 @@ int main(int argc, char **argv) {
         //-------------------------------------------------------------------------------------------
         pairs_sim->euler(dt);
 
-        // Communicate
+        // Reneighbor
         //-------------------------------------------------------------------------------------------
-        pairs_sim->communicate(t);
-        // PairsAccessor requires an update when particles are communicated
+        pairs_sim->reneighbor();
+        // PairsAccessor requires an update when particles are reneighbored
         ac->update();
 
         pairs::vtk_write_data(pairs_runtime, "output/dem_sd_local", 0, ac->nlocal(), t, vtk_freq);
diff --git a/examples/modular/sd_4.cpp b/examples/modular/sd_4.cpp
index e83e66d1d4e29f2d281b2673ace6047cba7a3eb2..791a30415998357e50b7165968a9e4248a616e91 100644
--- a/examples/modular/sd_4.cpp
+++ b/examples/modular/sd_4.cpp
@@ -58,34 +58,36 @@ int main(int argc, char **argv) {
     double sphere_spacing = 0.4;
     pairs::dem_sc_grid(pairs_runtime, 10, 10, 15,  sphere_spacing, diameter_min, diameter_min, diameter_max,    2,      100,    2);
     
-    double lcw = diameter_max * 1.01;       // Linked-cell width
+    double cell_width = diameter_max;
     double interaction_radius = diameter_max;
-    pairs_sim->setup_cells(lcw, lcw, lcw, interaction_radius);
 
     pairs_sim->update_mass_and_inertia();
 
+    pairs_sim->setCellWidth(cell_width, cell_width, cell_width);
+    pairs_sim->setInteractionRadius(interaction_radius);
+    pairs_sim->updateDomain();
+
     int num_timesteps = 4000;
     int vtk_freq = 20;
     int rebalance_freq = 200;
     double dt = 1e-3;
 
     pairs::vtk_write_subdom(pairs_runtime, "output/subdom_init", 0);
-
     
     for (int t=0; t<num_timesteps; ++t){
         if ((t % vtk_freq==0) && pairs_sim->rank()==0) std::cout << "Timestep: " << t << std::endl;
         
-        if (t % rebalance_freq == 0){ 
-            pairs_sim->update_domain();
-        }
-        
-        pairs_sim->update_cells(t); 
         
         pairs_sim->gravity(); 
         pairs_sim->spring_dashpot(); 
         pairs_sim->euler(dt); 
         
-        pairs_sim->communicate(t);
+        if (t % rebalance_freq == 0){ 
+            pairs_sim->updateDomain();
+        }
+        else {
+            pairs_sim->reneighbor();
+        }
 
         if (t % vtk_freq==0){
             pairs::vtk_write_subdom(pairs_runtime, "output/subdom", t);
diff --git a/examples/modular/sphere_box_global.cpp b/examples/modular/sphere_box_global.cpp
index 3c307383ee14f35d5ad0c205c4924df471b4fd95..44390624e809ebfaf1abcd156222ea38d2f37816 100644
--- a/examples/modular/sphere_box_global.cpp
+++ b/examples/modular/sphere_box_global.cpp
@@ -60,11 +60,13 @@ int main(int argc, char **argv) {
     pairs::create_sphere(pairs_runtime, 15, 20, 15,     0, 4, 0,                50, 4, 0,       pairs::flags::GLOBAL);
     pairs::create_sphere(pairs_runtime, 15, 25, 4,      0, 0, 0,                50, 4, 0,       pairs::flags::GLOBAL | pairs::flags::FIXED); 
     
-    // Use the diameter of small particles to set up the cell list
-    double lcw = radius * 2;
-    pairs_sim->setup_cells(lcw, lcw, lcw, lcw);
     pairs_sim->update_mass_and_inertia();
-    pairs_sim->communicate(0);
+    
+    // Use the diameter of small particles to set up the cell list
+    double cell_width = radius * 2;
+    pairs_sim->setCellWidth(cell_width, cell_width, cell_width);
+    pairs_sim->setInteractionRadius(cell_width);
+    pairs_sim->updateDomain();
 
     int num_timesteps = 20000; 
     int vtk_freq = 100;
@@ -76,20 +78,20 @@ int main(int argc, char **argv) {
     for (int t=0; t<num_timesteps; ++t){
         if ((t % vtk_freq==0) && pairs_sim->rank()==0) std::cout << "Timestep: " << t << std::endl;
         
-        if (t % rebalance_freq == 0){ 
-            pairs_sim->update_domain();
-        }
-        
-        pairs_sim->update_cells(t); 
-
         pairs_sim->gravity(); 
         
         // All global and local interactions are contained within the 'spring_dashpot' module
         // You have the option to call spring_dashpot before or after 'gravity' or any other force-update module
         pairs_sim->spring_dashpot();     
 
-        pairs_sim->euler(dt); 
-        pairs_sim->communicate(t);
+        pairs_sim->euler(dt);
+
+        if (t % rebalance_freq == 0){ 
+            pairs_sim->updateDomain();
+        }
+        else {
+            pairs_sim->reneighbor();
+        }
         
         if (t % vtk_freq==0){
             pairs::vtk_with_rotation(pairs_runtime, pairs::Shapes::Box, "output/local_boxes", 0, pairs_sim->nlocal(), t);
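Note: sd_4.cpp and sphere_box_global.cpp now choose, once per step after integration, between the full updateDomain() and the cheaper reneighbor(). A sketch of that cadence with the same interface (rebalance_freq as in the examples above):

    for (int t = 0; t < num_timesteps; ++t) {
        pairs_sim->gravity();
        pairs_sim->spring_dashpot();
        pairs_sim->euler(dt);

        if (t % rebalance_freq == 0) {
            pairs_sim->updateDomain();   // may rebalance/repartition blocks across ranks
        } else {
            pairs_sim->reneighbor();     // routine exchange/ghost update between rebalances
        }
    }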
diff --git a/runtime/domain/ParticleDataHandling.hpp b/runtime/domain/ParticleDataHandling.hpp
index c13737c2aa395402ccf4308346b3e57eced057dc..58d4971186382df680c817412a39985b26c84714 100644
--- a/runtime/domain/ParticleDataHandling.hpp
+++ b/runtime/domain/ParticleDataHandling.hpp
@@ -163,7 +163,7 @@ public:
     }
 
     void serializeImpl(Block *const block, const BlockDataID&, mpi::SendBuffer& buffer, const uint_t child, bool check_child) {
-        auto ptr = buffer.allocate<uint_t>();
+        auto ptr = buffer.allocate<int>();
         double aabb_check[6];
 
         if(check_child) {
@@ -275,21 +275,21 @@ public:
         ps->setTrackedVariableAsInteger("nlocal", nlocal - nserialized);
         ps->setTrackedVariableAsInteger("nghost", 0);
         
-        *ptr = (uint_t) nserialized;
+        *ptr = (int) nserialized;
     }
 
     void deserializeImpl(IBlock *const, const BlockDataID&, mpi::RecvBuffer& buffer) {
         int nlocal = ps->getTrackedVariableAsInteger("nlocal");
-        int particle_capacity = ps->getTrackedVariableAsInteger("particle_capacity");
         real_t real_tmp;
         int int_tmp;
-        uint_t nrecv;
+        int nrecv;
         uint64_t uint64_tmp;
 
         buffer >> nrecv;
         
         // TODO: Check if there is enough particle capacity for the new particles, when there is not,
         // all properties and arrays which have particle_capacity as one of their dimensions must be reallocated
+        // int particle_capacity = ps->getTrackedVariableAsInteger("particle_capacity");
         // PAIRS_ASSERT(nlocal + nrecv < particle_capacity);
 
         for(int i = 0; i < nrecv; ++i) {
diff --git a/runtime/domain/block_forest.cpp b/runtime/domain/block_forest.cpp
index 03275362e44bd4b07cf6c4d293e49aacd2fb25cd..42543cbc026eba931d2fb993b36766b2650f6e40 100644
--- a/runtime/domain/block_forest.cpp
+++ b/runtime/domain/block_forest.cpp
@@ -31,29 +31,31 @@ BlockForest::BlockForest(
 }
 
 BlockForest::BlockForest(PairsRuntime *ps_, const std::shared_ptr<walberla::blockforest::BlockForest> &bf) :
-        forest(bf),
         DomainPartitioner(bf->getDomain().xMin(), bf->getDomain().xMax(),
                         bf->getDomain().yMin(), bf->getDomain().yMax(),
-                        bf->getDomain().zMin(), bf->getDomain().zMax()), 
-        ps(ps_) {
+                        bf->getDomain().zMin(), bf->getDomain().zMax()),
+        ps(ps_),
+        forest(bf)
+{
             subdom = new real_t[ndims * 2];
             mpiManager = walberla::mpi::MPIManager::instance();
             world_size = mpiManager->numProcesses();
             rank = mpiManager->rank();
             this->info = make_shared<walberla::blockforest::InfoCollection>();
+            this->updateNeighborhood();
 }
 
 void BlockForest::updateNeighborhood() {
     std::map<int, std::vector<walberla::math::AABB>> neighborhood;
     std::map<int, std::vector<walberla::BlockID>> blocks_pushed;
     auto me = mpiManager->rank();
-    this->nranks = 0;
-    this->total_aabbs = 0;
+    this->num_neigh_ranks = 0;
+    this->total_num_neigh_aabbs = 0;
 
     ranks.clear();
-    naabbs.clear();
+    num_neigh_aabbs.clear();
     aabb_offsets.clear();
-    aabbs.clear();
+    neigh_aabbs.clear();
     for(auto& iblock: *forest) {
         auto block = static_cast<walberla::blockforest::Block *>(&iblock);
         for(uint neigh = 0; neigh < block->getNeighborhoodSize(); ++neigh) {
@@ -65,7 +67,7 @@ void BlockForest::updateNeighborhood() {
                 walberla::math::AABB neighbor_aabb = block->getNeighborAABB(neigh);
                 auto begin = blocks_pushed[neighbor_rank].begin();
                 auto end = blocks_pushed[neighbor_rank].end();
-                
+
                 if(find_if(begin, end, [neighbor_id](const auto &bp) { return bp == neighbor_id; }) == end) {
                     neighborhood[neighbor_rank].push_back(neighbor_aabb);
                     blocks_pushed[neighbor_rank].push_back(neighbor_id);
@@ -75,35 +77,41 @@ void BlockForest::updateNeighborhood() {
     }
 
     for(auto& nbh: neighborhood) {
-        auto rank = nbh.first;
+        auto neigh_rank = nbh.first;
         auto aabb_list = nbh.second;
-        ranks.push_back((int) rank);
-        aabb_offsets.push_back(this->total_aabbs);
-        naabbs.push_back((int) aabb_list.size());
+        ranks.push_back((int) neigh_rank);
+        aabb_offsets.push_back(this->total_num_neigh_aabbs);
+        num_neigh_aabbs.push_back((int) aabb_list.size());
 
         for(auto &aabb: aabb_list) {
-            aabbs.push_back(aabb.xMin());
-            aabbs.push_back(aabb.xMax());
-            aabbs.push_back(aabb.yMin());
-            aabbs.push_back(aabb.yMax());
-            aabbs.push_back(aabb.zMin());
-            aabbs.push_back(aabb.zMax());
-            this->total_aabbs++;
+            neigh_aabbs.push_back(aabb.xMin());
+            neigh_aabbs.push_back(aabb.xMax());
+            neigh_aabbs.push_back(aabb.yMin());
+            neigh_aabbs.push_back(aabb.yMax());
+            neigh_aabbs.push_back(aabb.zMin());
+            neigh_aabbs.push_back(aabb.zMax());
+            this->total_num_neigh_aabbs++;
         }
 
-        this->nranks++;
+        this->num_neigh_ranks++;
     }
+
+    this->is_neighborhood_up_to_date = true;
 }
 
 void BlockForest::copyRuntimeArray(const std::string& name, void *dest, const int size) {
-    void *src = name.compare("ranks") == 0          ? static_cast<void *>(ranks.data()) :
-                name.compare("naabbs") == 0         ? static_cast<void *>(naabbs.data()) :
-                name.compare("aabb_offsets") == 0   ? static_cast<void *>(aabb_offsets.data()) :
-                name.compare("aabbs") == 0          ? static_cast<void *>(aabbs.data()) :
-                name.compare("subdom") == 0         ? static_cast<void *>(subdom) : nullptr;
+    void *src = name.compare("ranks") == 0              ? static_cast<void *>(ranks.data()) :
+                name.compare("num_neigh_aabbs") == 0    ? static_cast<void *>(num_neigh_aabbs.data()) :
+                name.compare("aabb_offsets") == 0       ? static_cast<void *>(aabb_offsets.data()) :
+                name.compare("neigh_aabbs") == 0        ? static_cast<void *>(neigh_aabbs.data()) :
+                name.compare("local_aabbs") == 0        ? static_cast<void *>(local_aabbs.data()) :
+                name.compare("non_empty_local_aabbs") == 0 ? static_cast<void *>(non_empty_local_aabbs.data()) :
+                name.compare("has_non_empty_aabb_in_neighborhood_of_rank") == 0 
+                    ? static_cast<void *>(has_non_empty_aabb_in_neighborhood_of_rank.data()) :
+                name.compare("subdom") == 0             ? static_cast<void *>(subdom) : nullptr;
 
     PAIRS_ASSERT(src != nullptr);
-    bool is_real = (name.compare("aabbs") == 0) || (name.compare("subdom") == 0);
+    bool is_real = (name.compare("neigh_aabbs") == 0) || (name.compare("subdom") == 0)  || (name.compare("local_aabbs") == 0);
     int tsize = is_real ? sizeof(real_t) : sizeof(int);
     std::memcpy(dest, src, size * tsize);
 }
@@ -113,7 +121,7 @@ void BlockForest::updateWeights() {
 
     info->clear();
 
-    int sum_block_locals = 0;
+    size_t sum_block_locals = 0;
     // Compute the weights for my blocks and their children
     for(auto& iblock: *forest) {
         auto block = static_cast<walberla::blockforest::Block *>(&iblock);
@@ -139,13 +147,18 @@ void BlockForest::updateWeights() {
         }
     }
     
-    int non_globals = ps->getTrackedVariableAsInteger("nlocal") - UniqueID::getNumGlobals();
+    size_t non_globals = ps->getTrackedVariableAsInteger("nlocal") - UniqueID::getNumGlobals();
     
     if(sum_block_locals!=non_globals){
         std::cout << "Warning: " << non_globals - sum_block_locals << " particles in rank " << rank << 
         " may get lost in the next rebalancing." << std::endl;
     }
 
+// Neighbor weights are currently not used, but they may be useful in the future for other optimizations.
+// Note: the neighborhood cannot be built from the weights alone, because a neighbor that is empty in one timestep
+// may become non-empty in the next. Therefore all neighbors, irrespective of their weights, must be added to the neighborhood.
+
+/*
     // Send the weights of my blocks and their children to the neighbors of my blocks
     for(auto& iblock: *forest) {
         auto block = static_cast<walberla::blockforest::Block *>(&iblock);
@@ -177,12 +190,12 @@ void BlockForest::updateWeights() {
             info->insert(val);
         }
     }
+*/
 }
 
 walberla::Vector3<int> BlockForest::getBlockConfig() {
     real_t area[3];
     real_t best_surf = 0.0;
-    int ndims = 3;
     int d = 0;
     int nranks[3] = {1, 1, 1};
 
@@ -233,10 +246,30 @@ void BlockForest::setBoundingBox() {
     for (int i=0; i<6; ++i) subdom[i] = 0.0;
     if (forest->empty()) return;
 
-    auto aabb_union = forest->begin()->getAABB();
+    walberla::math::AABB aabb_union;
+    int block_idx = 0;
+    bool init_aabb = true;      // Avoid merging with the default zero-initialized AABB
     for(auto& iblock: *forest) {
         auto block = static_cast<walberla::blockforest::Block *>(&iblock);
-        aabb_union.merge(block->getAABB());
+        if(non_empty_local_aabbs[block_idx]){
+            auto aabb = block->getAABB();
+            if(init_aabb){
+                aabb_union = aabb;
+                init_aabb = false;
+            }
+            else{
+                aabb_union.merge(aabb);
+            }
+
+            if(rank==15){
+                vtk_write_aabb(this->ps, "output/non_empty_aabbs_", block_idx, 
+                    aabb.xMin(), aabb.xMax(), 
+                    aabb.yMin(), aabb.yMax(), 
+                    aabb.zMin(), aabb.zMax());
+            }
+
+        }
+        ++block_idx;
     }
 
     subdom[0] = aabb_union.xMin();
@@ -263,6 +296,7 @@ void BlockForest::initialize(int *argc, char ***argv) {
     forest = walberla::blockforest::createBlockForest(domain, block_config, pbc, world_size, ref_level);
 
     this->info = make_shared<walberla::blockforest::InfoCollection>();
+    this->updateNeighborhood();
 
     if (rank==0) {
         std::cout << "Domain Partitioner: BlockForest" << std::endl;
@@ -273,29 +307,73 @@ void BlockForest::initialize(int *argc, char ***argv) {
     }
 }
 
-void BlockForest::update() {
-    if(balance_workload) {
+void BlockForest::updateLocal() {
+    has_non_empty_aabb_in_neighborhood_of_rank.clear();
+    non_empty_local_aabbs.clear();
+    local_aabbs.clear();
+
+    if (forest->empty()) return;
+    if (!is_neighborhood_up_to_date) this->updateNeighborhood();
+
+    int block_idx = 0;
+    for(auto& iblock: *forest) {
+        auto block = static_cast<walberla::blockforest::Block *>(&iblock);
+        auto aabb = block->getAABB();
+        local_aabbs.push_back(aabb.xMin());
+        local_aabbs.push_back(aabb.xMax());
+        local_aabbs.push_back(aabb.yMin());
+        local_aabbs.push_back(aabb.yMax());
+        local_aabbs.push_back(aabb.zMin());
+        local_aabbs.push_back(aabb.zMax());
+        ++block_idx;
+    }
+
+    this->num_local_aabbs = block_idx;
+    non_empty_local_aabbs.resize(this->num_local_aabbs, 0);
+
+    determine_non_empty_aabbs(this->ps, this->num_local_aabbs, local_aabbs.data(), non_empty_local_aabbs.data());
+
+    std::map<int, bool> has_particles_for_rank;
+    block_idx = 0;
+    for(auto& iblock: *forest) {
+        auto block = static_cast<walberla::blockforest::Block *>(&iblock);
+        for(uint neigh = 0; neigh < block->getNeighborhoodSize(); ++neigh) {
+            auto neighbor_rank = walberla::int_c(block->getNeighborProcess(neigh));
+            has_particles_for_rank[neighbor_rank] = has_particles_for_rank[neighbor_rank] || non_empty_local_aabbs[block_idx];
+        }
+        ++block_idx;
+    }
+
+    for(auto& r: ranks) {
+        has_non_empty_aabb_in_neighborhood_of_rank.push_back(has_particles_for_rank[r]);
+    }
+
+    this->setBoundingBox();
+}
+
+void BlockForest::rebalance() {
+    if(balance_workload){
         if(!forest->loadBalancingFunctionRegistered()){
             std::cerr << "Workload balancer is not initialized." << std::endl;
             exit(-1);
         }
 
         this->updateWeights();
+
         const int nlocal = ps->getTrackedVariableAsInteger("nlocal");
         for(auto &prop: ps->getProperties()) {
             if(!prop.isVolatile()) {
-                const int ptypesize = get_proptype_size(prop.getType());
+                const size_t ptypesize = get_proptype_size(prop.getType());
                 ps->copyPropertyToHost(prop, pairs::WriteAfterRead, nlocal*ptypesize);
             }
         }
         
         // PAIRS_DEBUG("Rebalance\n");
         if (rank==0) std::cout << "Rebalance" << std::endl;
+
         forest->refresh(); 
+        is_neighborhood_up_to_date = false;
     }
-
-    this->updateNeighborhood();
-    this->setBoundingBox();
 }
 
 void BlockForest::initWorkloadBalancer(LoadBalancingAlgorithms algorithm, size_t regridMin, size_t regridMax) {
@@ -401,7 +479,7 @@ int BlockForest::isWithinSubdomain(real_t x, real_t y, real_t z) {
     return false;
 }
 
-void BlockForest::communicateSizes(int dim, const int *nsend, int *nrecv) {
+void BlockForest::communicateSizes(int, const int *nsend, int *nrecv) {
     std::vector<MPI_Request> send_requests;
     std::vector<MPI_Request> recv_requests;
     size_t nranks = 0;
@@ -420,15 +498,15 @@ void BlockForest::communicateSizes(int dim, const int *nsend, int *nrecv) {
     }
 
     if(!send_requests.empty()) {
-        MPI_Waitall(send_requests.size(), send_requests.data(), MPI_STATUSES_IGNORE);
+        MPI_Waitall((int)(send_requests.size()), send_requests.data(), MPI_STATUSES_IGNORE);
     }
     if(!recv_requests.empty()) {
-        MPI_Waitall(recv_requests.size(), recv_requests.data(), MPI_STATUSES_IGNORE);
+        MPI_Waitall((int)(recv_requests.size()), recv_requests.data(), MPI_STATUSES_IGNORE);
     }
 }
 
 void BlockForest::communicateData(
-    int dim, int elem_size,
+    int, int elem_size,
     const real_t *send_buf, const int *send_offsets, const int *nsend,
     real_t *recv_buf, const int *recv_offsets, const int *nrecv) {
 
@@ -456,24 +534,24 @@ void BlockForest::communicateData(
     }
 
     if(!send_requests.empty()) {
-        MPI_Waitall(send_requests.size(), send_requests.data(), MPI_STATUSES_IGNORE);
+        MPI_Waitall((int)(send_requests.size()), send_requests.data(), MPI_STATUSES_IGNORE);
     }
 
     if(!recv_requests.empty()) {
-        MPI_Waitall(recv_requests.size(), recv_requests.data(), MPI_STATUSES_IGNORE);
+        MPI_Waitall((int)(recv_requests.size()), recv_requests.data(), MPI_STATUSES_IGNORE);
     }
 }
 
 void BlockForest::communicateDataReverse(
-    int dim, int elem_size,
+    int, int elem_size,
     const real_t *send_buf, const int *send_offsets, const int *nsend,
     real_t *recv_buf, const int *recv_offsets, const int *nrecv) {
 
-        this->communicateData(dim, elem_size,send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv);
+        this->communicateData(0, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv);
 }
 
 void BlockForest::communicateAllData(
-    int ndims, int elem_size,
+    int, int elem_size,
     const real_t *send_buf, const int *send_offsets, const int *nsend,
     real_t *recv_buf, const int *recv_offsets, const int *nrecv) {
 
diff --git a/runtime/domain/block_forest.hpp b/runtime/domain/block_forest.hpp
index 039f86769c8eb3b20c39aee7dd22a179d9221707..2c7675589b9f1deb16b71b2a8fdd758b1efbbf1b 100644
--- a/runtime/domain/block_forest.hpp
+++ b/runtime/domain/block_forest.hpp
@@ -31,17 +31,21 @@ class PairsRuntime;
 
 class BlockForest : public DomainPartitioner {
 private:
+    PairsRuntime *ps;
     std::shared_ptr<walberla::mpi::MPIManager> mpiManager;
     std::shared_ptr<walberla::blockforest::BlockForest> forest;
     std::shared_ptr<walberla::blockforest::InfoCollection> info;
     std::vector<int> ranks;
-    std::vector<int> naabbs;
+    std::vector<double> local_aabbs;
+    std::vector<int> has_non_empty_aabb_in_neighborhood_of_rank;
+    std::vector<int> non_empty_local_aabbs;
+    std::vector<int> num_neigh_aabbs;
     std::vector<int> aabb_offsets;
-    std::vector<double> aabbs;
-    PairsRuntime *ps;
+    std::vector<double> neigh_aabbs;
     real_t *subdom;
-    int world_size, rank, nranks, total_aabbs;
+    int world_size, rank, num_neigh_ranks, total_num_neigh_aabbs, num_local_aabbs;
     bool balance_workload = false;
+    bool is_neighborhood_up_to_date = false;
 
 public:
     BlockForest(
@@ -57,21 +61,22 @@ public:
     void initialize(int *argc, char ***argv);
     void initWorkloadBalancer(LoadBalancingAlgorithms algorithm, size_t regridMin, size_t regridMax);
 
-    void update();
+    void updateLocal();
+    void updateNeighborhood();
+    void rebalance();
     void finalize();
     int getWorldSize() const { return world_size; }
     int getRank() const { return rank; }
-    int getNumberOfNeighborRanks() { return this->nranks; }
-    int getNumberOfNeighborAABBs() { return this->total_aabbs; }
+    int getNumberOfNeighborRanks() { return this->num_neigh_ranks; }
+    int getNumberOfNeighborAABBs() { return this->total_num_neigh_aabbs; }
+    int getNumberOfLocalAABBs() { return this->num_local_aabbs; }
     double getSubdomMin(int dim) const { return subdom[2*dim + 0];}
     double getSubdomMax(int dim) const { return subdom[2*dim + 1];}
 
-    void updateNeighborhood();
     void updateWeights();
     walberla::math::Vector3<int> getBlockConfig();
     int getInitialRefinementLevel(int num_processes);
     void setBoundingBox();
-    void rebalance();
 
     int isWithinSubdomain(real_t x, real_t y, real_t z);
     void copyRuntimeArray(const std::string& name, void *dest, const int size);
diff --git a/runtime/domain/boundary_weights.cpp b/runtime/domain/boundary_weights.cpp
index 3a67d29386c5d1d11b057b3a6435a3c8d7496626..bf9b0428b473836e05157603fb0f77c36d7814ff 100644
--- a/runtime/domain/boundary_weights.cpp
+++ b/runtime/domain/boundary_weights.cpp
@@ -40,7 +40,34 @@ void compute_boundary_weights(
     //       And neighbor blocks are going to change after rebalancing.
     // const int nghost = ps->getTrackedVariableAsInteger("nghost");
     *comm_weight = 0;
+}
+
+void determine_non_empty_aabbs(PairsRuntime *ps, int num_aabbs, real_t *aabbs, int *non_empty_aabbs){
+    const int particle_capacity = ps->getTrackedVariableAsInteger("particle_capacity");
+    const int nlocal = ps->getTrackedVariableAsInteger("nlocal");
+    auto position_prop = ps->getPropertyByName("position");
+    auto flags_prop = ps->getPropertyByName("flags");
+
+    real_t *position_ptr = static_cast<real_t *>(position_prop.getHostPointer());
+    int *flags_ptr = static_cast<int *>(flags_prop.getHostPointer());
 
+    for(int i = 0; i < nlocal; ++i) {
+        if (pairs_host_interface::get_flags(flags_ptr, i) & (pairs::flags::INFINITE | pairs::flags::GLOBAL)) {
+            continue;
+        }
+
+        real_t pos_x = pairs_host_interface::get_position(position_ptr, i, 0, particle_capacity);
+        real_t pos_y = pairs_host_interface::get_position(position_ptr, i, 1, particle_capacity);
+        real_t pos_z = pairs_host_interface::get_position(position_ptr, i, 2, particle_capacity);
+        for(int n = 0; n < num_aabbs; ++n){
+            if( pos_x >= aabbs[n*6 + 0] && pos_x < aabbs[n*6 + 1] &&
+                pos_y >= aabbs[n*6 + 2] && pos_y < aabbs[n*6 + 3] &&
+                pos_z >= aabbs[n*6 + 4] && pos_z < aabbs[n*6 + 5]) {
+                    non_empty_aabbs[n] = true;
+                    break;
+            }
+        }
+    }
 }
 
 }
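Note: determine_non_empty_aabbs marks an AABB as non-empty as soon as one local particle that is neither INFINITE nor GLOBAL lies inside it, using a half-open containment test (>= on the lower bound, < on the upper bound). A self-contained sketch of that test, without the runtime's property accessors (hypothetical names):

    #include <array>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Half-open point-in-AABB test matching the flat layout used above:
    // aabb = {xmin, xmax, ymin, ymax, zmin, zmax}.
    static bool inside(const std::array<double, 6> &aabb, double x, double y, double z) {
        return x >= aabb[0] && x < aabb[1] &&
               y >= aabb[2] && y < aabb[3] &&
               z >= aabb[4] && z < aabb[5];
    }

    int main() {
        std::vector<std::array<double, 6>> aabbs = {{0, 1, 0, 1, 0, 1}, {1, 2, 0, 1, 0, 1}};
        std::vector<std::array<double, 3>> particles = {{0.5, 0.5, 0.5}, {0.25, 0.75, 0.1}};

        std::vector<int> non_empty(aabbs.size(), 0);
        for (const auto &p : particles) {
            for (std::size_t n = 0; n < aabbs.size(); ++n) {
                if (inside(aabbs[n], p[0], p[1], p[2])) { non_empty[n] = 1; break; }
            }
        }
        for (std::size_t n = 0; n < aabbs.size(); ++n)
            std::printf("aabb %zu non-empty: %d\n", n, non_empty[n]);
    }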
diff --git a/runtime/domain/boundary_weights.cu b/runtime/domain/boundary_weights.cu
index 191139fa245dd5104afdb4793fbe0f27cdaa4441..ede18edb080c68f14f2e68e5a959b4ede52474b7 100644
--- a/runtime/domain/boundary_weights.cu
+++ b/runtime/domain/boundary_weights.cu
@@ -105,4 +105,70 @@ void compute_boundary_weights(
     *comm_weight = 0;
 }
 
+__global__ void determine_non_empty_aabbs_kernel(int num_aabbs, real_t *aabbs, int *non_empty_aabbs, 
+                                            int nlocal, int particle_capacity, real_t *position_ptr, int *flags_ptr){
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int stride = blockDim.x * gridDim.x;
+    for(int i=idx; i<nlocal; i+=stride) {
+        if (pairs_cuda_interface::get_flags(flags_ptr, i) & (pairs::flags::INFINITE | pairs::flags::GLOBAL)) {
+            continue;
+        }
+
+        real_t pos_x = pairs_cuda_interface::get_position(position_ptr, i, 0, particle_capacity);
+        real_t pos_y = pairs_cuda_interface::get_position(position_ptr, i, 1, particle_capacity);
+        real_t pos_z = pairs_cuda_interface::get_position(position_ptr, i, 2, particle_capacity);
+        for(int n = 0; n < num_aabbs; ++n){
+            if( pos_x >= aabbs[n*6 + 0] && pos_x < aabbs[n*6 + 1] &&
+                pos_y >= aabbs[n*6 + 2] && pos_y < aabbs[n*6 + 3] &&
+                pos_z >= aabbs[n*6 + 4] && pos_z < aabbs[n*6 + 5]) {
+                    non_empty_aabbs[n] = true;
+                    break;
+            }
+        }
+    }
+}
+
+void determine_non_empty_aabbs(PairsRuntime *ps, int num_aabbs, real_t *aabbs, int *non_empty_aabbs){  
+    const int particle_capacity = ps->getTrackedVariableAsInteger("particle_capacity");
+    const int nlocal = ps->getTrackedVariableAsInteger("nlocal");
+    
+    if (nlocal == 0) {
+        return;
+    }
+
+    // Shortcut: with local particles present and a single local AABB, mark it non-empty.
+    if (num_aabbs == 1) {
+        non_empty_aabbs[0] = 1;
+        return;
+    }
+    
+    auto position_prop = ps->getPropertyByName("position");
+    auto flags_prop = ps->getPropertyByName("flags");
+
+    real_t *position_ptr = static_cast<real_t *>(position_prop.getDevicePointer());
+    int *flags_ptr = static_cast<int *>(flags_prop.getDevicePointer());
+
+    ps->copyPropertyToDevice(position_prop.getId(), ReadOnly);
+    ps->copyPropertyToDevice(flags_prop.getId(), ReadOnly);
+
+    size_t aabbs_size = num_aabbs * 6 * sizeof(real_t);
+    size_t non_empty_aabbs_size = num_aabbs * sizeof(int);
+
+    real_t *aabbs_d = (real_t *) device_alloc(aabbs_size);
+    CUDA_ASSERT(cudaMemcpy(aabbs_d, aabbs, aabbs_size, cudaMemcpyHostToDevice));
+
+    int *non_empty_aabbs_d = (int *) device_alloc(non_empty_aabbs_size);
+    CUDA_ASSERT(cudaMemset(non_empty_aabbs_d, 0, non_empty_aabbs_size));
+
+    const int tpb = 64;
+    const int nblocks = (nlocal + tpb -1) / tpb;
+
+    determine_non_empty_aabbs_kernel<<< nblocks, tpb >>>(num_aabbs, aabbs_d, non_empty_aabbs_d, 
+        nlocal, particle_capacity, position_ptr, flags_ptr);
+
+    CUDA_ASSERT(cudaPeekAtLastError());
+    CUDA_ASSERT(cudaDeviceSynchronize());
+    CUDA_ASSERT(cudaMemcpy(non_empty_aabbs, non_empty_aabbs_d, non_empty_aabbs_size, cudaMemcpyDeviceToHost));
+}
+
 }
diff --git a/runtime/domain/boundary_weights.hpp b/runtime/domain/boundary_weights.hpp
index 3b7fd22e4132b243531e07ed885bc1883b229820..6fb404459ba386ca53ea351816ea9186dd66b39f 100644
--- a/runtime/domain/boundary_weights.hpp
+++ b/runtime/domain/boundary_weights.hpp
@@ -17,4 +17,6 @@ void compute_boundary_weights(
     real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax,
     long unsigned int *comp_weight, long unsigned int *comm_weight);
 
+void determine_non_empty_aabbs(PairsRuntime *ps, int num_aabbs, real_t *aabbs, int *non_empty_aabbs);
+
 }
diff --git a/runtime/domain/domain_partitioning.hpp b/runtime/domain/domain_partitioning.hpp
index 3dfdaaebfa8c9f91f58705fd1b750f53569afc04..8d41618e84722d2acbd89aed832d5c5c686cd1a3 100644
--- a/runtime/domain/domain_partitioning.hpp
+++ b/runtime/domain/domain_partitioning.hpp
@@ -43,10 +43,13 @@ public:
     virtual double getSubdomMax(int dim) const = 0;
     virtual void initialize(int *argc, char ***argv) = 0;
     virtual void initWorkloadBalancer(LoadBalancingAlgorithms algorithm, size_t regridMin, size_t regridMax) = 0;
-    virtual void update() = 0;
+    virtual void updateLocal() = 0;
+    virtual void updateNeighborhood() = 0;
+    virtual void rebalance() = 0;
     virtual int getWorldSize() const = 0;
     virtual int getRank() const = 0;
     virtual int getNumberOfNeighborAABBs() = 0;
+    virtual int getNumberOfLocalAABBs() = 0;
     virtual int getNumberOfNeighborRanks() = 0;
     virtual int isWithinSubdomain(real_t x, real_t y, real_t z) = 0;
     virtual void copyRuntimeArray(const std::string& name, void *dest, const int size) = 0;
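Note: the former single update() entry point is split into updateLocal(), updateNeighborhood() and rebalance(), plus the new getNumberOfLocalAABBs() query. Based on the BlockForest implementation earlier in this diff, one plausible way a caller drives the split interface is the following (sketch only; the actual call sequence is emitted by the code generator):

    // rebalance() may refresh the block forest and marks the cached neighborhood as stale.
    dom_part->rebalance();

    // updateLocal() rebuilds the neighborhood if it is stale, gathers the local AABBs,
    // flags the non-empty ones, and recomputes the subdomain bounding box.
    dom_part->updateLocal();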
diff --git a/runtime/domain/regular_6d_stencil.cpp b/runtime/domain/regular_6d_stencil.cpp
index 699ea93eb14d724e3e5c024b70ff8c10d16b6d19..b469af0170a6437c4ee6c46278ebaa43ffabf4a0 100644
--- a/runtime/domain/regular_6d_stencil.cpp
+++ b/runtime/domain/regular_6d_stencil.cpp
@@ -95,10 +95,6 @@ void Regular6DStencil::initialize(int *argc, char ***argv) {
     }
 }
 
-void Regular6DStencil::initWorkloadBalancer(LoadBalancingAlgorithms algorithm, size_t regridMin, size_t regridMax) {}
-
-void Regular6DStencil::update() {}
-
 void Regular6DStencil::finalize() {
     MPI_Finalize();
 }
diff --git a/runtime/domain/regular_6d_stencil.hpp b/runtime/domain/regular_6d_stencil.hpp
index b4a9e5c6634c6f15c89041f539f0b955ecce992f..f1ce90ce0c08fc7b38e5b8738299215bfd872bca 100644
--- a/runtime/domain/regular_6d_stencil.hpp
+++ b/runtime/domain/regular_6d_stencil.hpp
@@ -51,14 +51,17 @@ public:
     void setConfig();
     void setBoundingBox();
     void initialize(int *argc, char ***argv);
-    void initWorkloadBalancer(LoadBalancingAlgorithms algorithm, size_t regridMin, size_t regridMax);
-    void update();
+    void initWorkloadBalancer(LoadBalancingAlgorithms, size_t, size_t) {}
+    void updateLocal() {}
+    void updateNeighborhood() {}
+    void rebalance() {}
     void finalize();
 
     int getWorldSize() const { return world_size; }
     int getRank() const { return rank; }
     int getNumberOfNeighborRanks() { return 6; }
     int getNumberOfNeighborAABBs() { return 6; }
+    int getNumberOfLocalAABBs() { return 1; }
     double getSubdomMin(int dim) const { return subdom_min[dim];}
     double getSubdomMax(int dim) const { return subdom_max[dim];}
 
diff --git a/runtime/pairs.cpp b/runtime/pairs.cpp
index a2382820e441daae0163511d12038919aae329f4..8803b8ea71e72f1ba06b27f3bc7523ba0e265977 100644
--- a/runtime/pairs.cpp
+++ b/runtime/pairs.cpp
@@ -399,6 +399,8 @@ void PairsRuntime::communicateSizes(int dim, const int *send_sizes, int *recv_si
     array_flags->setHostFlag(nrecv_id);
     array_flags->clearDeviceFlag(nrecv_id);
 
+    // MPI_Barrier(MPI_COMM_WORLD);
+
     this->getTimers()->start(TimerMarkers::MPI);
     this->getDomainPartitioner()->communicateSizes(dim, send_sizes, recv_sizes);
     this->getTimers()->stop(TimerMarkers::MPI);
@@ -413,8 +415,6 @@ void PairsRuntime::communicateData(
     real_t *recv_buf_ptr = recv_buf;
     auto send_buf_array = getArrayByHostPointer(send_buf);
     auto recv_buf_array = getArrayByHostPointer(recv_buf);
-    auto send_buf_id = send_buf_array.getId();
-    auto recv_buf_id = recv_buf_array.getId();
     auto send_offsets_id = getArrayByHostPointer(send_offsets).getId();
     auto recv_offsets_id = getArrayByHostPointer(recv_offsets).getId();
     auto nsend_id = getArrayByHostPointer(nsend).getId();
@@ -447,11 +447,15 @@ void PairsRuntime::communicateData(
         }
     }
     
+    auto send_buf_id = send_buf_array.getId();
+    auto recv_buf_id = recv_buf_array.getId();
     copyArrayToHost(send_buf_id, Ignore, nsend_all * elem_size * sizeof(real_t));
     array_flags->setHostFlag(recv_buf_id);
     array_flags->clearDeviceFlag(recv_buf_id);
     #endif
 
+    // MPI_Barrier(MPI_COMM_WORLD);
+
     this->getTimers()->start(TimerMarkers::MPI);
     this->getDomainPartitioner()->communicateData(
         dim, elem_size, send_buf_ptr, send_offsets, nsend, recv_buf_ptr, recv_offsets, nrecv);
@@ -471,8 +475,6 @@ void PairsRuntime::communicateDataReverse(
     real_t *recv_buf_ptr = recv_buf;
     auto send_buf_array = getArrayByHostPointer(send_buf);
     auto recv_buf_array = getArrayByHostPointer(recv_buf);
-    auto send_buf_id = send_buf_array.getId();
-    auto recv_buf_id = recv_buf_array.getId();
     auto send_offsets_id = getArrayByHostPointer(send_offsets).getId();
     auto recv_offsets_id = getArrayByHostPointer(recv_offsets).getId();
     auto nsend_id = getArrayByHostPointer(nsend).getId();
@@ -505,11 +507,15 @@ void PairsRuntime::communicateDataReverse(
         }
     }
 
+    auto send_buf_id = send_buf_array.getId();
+    auto recv_buf_id = recv_buf_array.getId();
     copyArrayToHost(send_buf_id, Ignore, nsend_all * elem_size * sizeof(real_t));
     array_flags->setHostFlag(recv_buf_id);
     array_flags->clearDeviceFlag(recv_buf_id);
     #endif
 
+    // MPI_Barrier(MPI_COMM_WORLD);
+
     this->getTimers()->start(TimerMarkers::MPI);
     this->getDomainPartitioner()->communicateDataReverse(
         dim, elem_size, send_buf_ptr, send_offsets, nsend, recv_buf_ptr, recv_offsets, nrecv);
@@ -529,8 +535,6 @@ void PairsRuntime::communicateAllData(
     real_t *recv_buf_ptr = recv_buf;
     auto send_buf_array = getArrayByHostPointer(send_buf);
     auto recv_buf_array = getArrayByHostPointer(recv_buf);
-    auto send_buf_id = send_buf_array.getId();
-    auto recv_buf_id = recv_buf_array.getId();
     auto send_offsets_id = getArrayByHostPointer(send_offsets).getId();
     auto recv_offsets_id = getArrayByHostPointer(recv_offsets).getId();
     auto nsend_id = getArrayByHostPointer(nsend).getId();
@@ -563,11 +567,15 @@ void PairsRuntime::communicateAllData(
         }
     }
 
+    auto send_buf_id = send_buf_array.getId();
+    auto recv_buf_id = recv_buf_array.getId();
     copyArrayToHost(send_buf_id, Ignore, nsend_all * elem_size * sizeof(real_t));
     array_flags->setHostFlag(recv_buf_id);
     array_flags->clearDeviceFlag(recv_buf_id);
     #endif
 
+    // MPI_Barrier(MPI_COMM_WORLD);
+
     this->getTimers()->start(TimerMarkers::MPI);
     this->getDomainPartitioner()->communicateAllData(
         ndims, elem_size, send_buf_ptr, send_offsets, nsend, recv_buf_ptr, recv_offsets, nrecv);
@@ -587,12 +595,8 @@ void PairsRuntime::communicateContactHistoryData(
     real_t *recv_buf_ptr = recv_buf;
     auto send_buf_array = getArrayByHostPointer(send_buf);
     auto recv_buf_array = getArrayByHostPointer(recv_buf);
-    auto send_buf_id = send_buf_array.getId();
-    auto recv_buf_id = recv_buf_array.getId();
     auto contact_soffsets_id = getArrayByHostPointer(contact_soffsets).getId();
-    auto contact_roffsets_id = getArrayByHostPointer(contact_roffsets).getId();
     auto nsend_contact_id = getArrayByHostPointer(nsend_contact).getId();
-    auto nrecv_contact_id = getArrayByHostPointer(nrecv_contact).getId();
 
     copyArrayToHost(contact_soffsets_id, ReadOnly);
     copyArrayToHost(nsend_contact_id, ReadOnly);
@@ -609,11 +613,15 @@ void PairsRuntime::communicateContactHistoryData(
     send_buf_ptr = (real_t *) send_buf_array.getDevicePointer();
     recv_buf_ptr = (real_t *) recv_buf_array.getDevicePointer();
     #else
+    auto send_buf_id = send_buf_array.getId();
+    auto recv_buf_id = recv_buf_array.getId();
     copyArrayToHost(send_buf_id, Ignore, nsend_all * sizeof(real_t));
     array_flags->setHostFlag(recv_buf_id);
     array_flags->clearDeviceFlag(recv_buf_id);
     #endif
 
+    // MPI_Barrier(MPI_COMM_WORLD);
+
     this->getTimers()->start(TimerMarkers::MPI);
     this->getDomainPartitioner()->communicateSizes(dim, nsend_contact, nrecv_contact);
 
@@ -634,6 +642,7 @@ void PairsRuntime::communicateContactHistoryData(
     this->getTimers()->stop(TimerMarkers::MPI);
 
     #ifndef ENABLE_CUDA_AWARE_MPI
+    auto contact_roffsets_id = getArrayByHostPointer(contact_roffsets).getId();
     copyArrayToDevice(recv_buf_id, Ignore, nrecv_all * sizeof(real_t));
     copyArrayToDevice(contact_roffsets_id, Ignore);
     #endif
@@ -653,7 +662,11 @@ void PairsRuntime::allReduceInplaceSum(real_t *red_buffer, int num_elems){
     copyArrayToHost(buff_array, Ignore, num_elems * sizeof(real_t));
     #endif
 
+    // MPI_Barrier(MPI_COMM_WORLD);
+    
+    this->getTimers()->start(TimerMarkers::MPI);
     MPI_Allreduce(MPI_IN_PLACE, buff_ptr, num_elems, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
+    this->getTimers()->stop(TimerMarkers::MPI);
 
     #ifndef ENABLE_CUDA_AWARE_MPI
     copyArrayToDevice(buff_array, Ignore, num_elems * sizeof(real_t));
diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp
index 853e8ef49408ccceb309a31083c6786550eb07d1..31540997c862f9458ff7e3e59dae7eeaa6c853d6 100644
--- a/runtime/pairs.hpp
+++ b/runtime/pairs.hpp
@@ -326,8 +326,6 @@ public:
     template<typename Domain_T>
     void useDomain(const std::shared_ptr<Domain_T> &domain_ptr);
 
-    void updateDomain() { dom_part->update(); }
-
     DomainPartitioner *getDomainPartitioner() { return dom_part; }
     void communicateSizes(int dim, const int *send_sizes, int *recv_sizes);
 
diff --git a/runtime/timers.hpp b/runtime/timers.hpp
index 4c7c3dbccf74f125fddbce854d3768bc019e54ee..bb73df8f6ba9872923aa3dd8ec2784dce92a3685 100644
--- a/runtime/timers.hpp
+++ b/runtime/timers.hpp
@@ -78,7 +78,7 @@ public:
         ss << "\n\n";
 
         std::string output = ss.str();
-        MPI_File_write_ordered(file, output.c_str(), output.size(), MPI_CHAR, MPI_STATUS_IGNORE);
+        MPI_File_write_ordered(file, output.c_str(), (int)(output.size()), MPI_CHAR, MPI_STATUS_IGNORE);
         MPI_File_close(&file);
     }
 
diff --git a/runtime/utility/vtk.cpp b/runtime/utility/vtk.cpp
index 57de0eb5489594c6099969d881187c98989f3bb4..3907437b4b24ee394e66f1d7847c6c6df3e0c3d9 100644
--- a/runtime/utility/vtk.cpp
+++ b/runtime/utility/vtk.cpp
@@ -7,6 +7,58 @@
 namespace pairs {
 
 
+void vtk_write_halo_cells(PairsRuntime *ps, const char *filename, int timestep, 
+        int nhalo_cells, int *halo_cells, int *dim_cells, double *spacing, double *subdom){
+
+    // if(ts%20 == 0)
+    // vtk_write_halo_cells(pairs_runtime, "output/halo_cells", ts, 
+    //     pobj->halo_ncells, pobj->halo_cells, pobj->dim_cells, pobj->spacing, pobj->subdom);
+
+    std::ostringstream filename_oss;
+
+    filename_oss << filename << "_";
+    if(ps->getDomainPartitioner()->getWorldSize() > 1) {
+        filename_oss << "r" << ps->getDomainPartitioner()->getRank() << "_";
+    }
+
+    filename_oss << timestep << ".vtk";
+    std::ofstream out_file(filename_oss.str());
+
+    if(out_file.is_open()) {
+        out_file << "# vtk DataFile Version 2.0\n";
+        out_file << "Halo cells\n";
+        out_file << "ASCII\n";
+        out_file << "DATASET POLYDATA\n";
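+        // Halo cell indices are stored in halo_cells starting at slot 1, hence nhalo_cells-1 points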
+        out_file << "POINTS " << nhalo_cells-1 << " double\n";
+
+        out_file << std::fixed << std::setprecision(6);
+        int halo_idx = 1;
+        for(int i = 0; i < dim_cells[0]; i++) {
+            for(int j = 0; j < dim_cells[1]; j++) {
+                for(int k = 0; k < dim_cells[2]; k++) {
+                    int flat_idx = i*dim_cells[1]*dim_cells[2] + j*dim_cells[2] + k + 1;
+                    if(halo_cells[halo_idx] == flat_idx){
+                        // Cell centers:
+                        out_file << (i-0.5)*spacing[0] + subdom[0] << " ";
+                        out_file << (j-0.5)*spacing[1] + subdom[2] << " ";
+                        out_file << (k-0.5)*spacing[2] + subdom[4] << "\n";
+                        ++halo_idx;
+                    }
+                }
+            }
+        }
+
+        out_file << "\n\n";
+        out_file.close();
+    }
+    else {
+        std::cerr << "Failed to open " << filename_oss.str() << std::endl;
+        exit(-1);
+    }
+
+}
+
+
 void vtk_with_rotation(
     PairsRuntime *ps, Shapes shape, const char *filename, int start, int end, int timestep, int frequency) {
 
diff --git a/runtime/utility/vtk.hpp b/runtime/utility/vtk.hpp
index ac369d5c776f06ae626ab142e16e79dd74b81c7e..62d6042c45dcf06d6b7c261620734c0aeea93233 100644
--- a/runtime/utility/vtk.hpp
+++ b/runtime/utility/vtk.hpp
@@ -7,6 +7,9 @@ namespace pairs {
 void vtk_with_rotation(
     PairsRuntime *ps, Shapes shape, const char *filename, int start, int end, int timestep, int frequency=1);
 
+void vtk_write_halo_cells(PairsRuntime *ps, const char *filename, int timestep, 
+    int nhalo_cells, int *halo_cells, int *dim_cells, double *spacing, double *subdom);
+
 void vtk_write_aabb(PairsRuntime *ps, const char *filename, int num,
     double xmin, double xmax, 
     double ymin, double ymax, 
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 9a07c9444bb1a16d1c7628b3a3d6decd898d11af..46276312d1b5272316efd2895dbfef440fc16ae6 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -15,7 +15,7 @@ from pairs.ir.functions import Call
 from pairs.ir.kernel import KernelLaunch
 from pairs.ir.layouts import Layouts
 from pairs.ir.lit import Lit
-from pairs.ir.loops import For, Iter, While, Continue
+from pairs.ir.loops import For, Iter, While, Continue, Break
 from pairs.ir.quaternions import Quaternion, QuaternionAccess, QuaternionOp
 from pairs.ir.math import MathFunction
 from pairs.ir.matrices import Matrix, MatrixAccess, MatrixOp
@@ -542,6 +542,12 @@ class CGen:
             else:
                 self.print("return;")
 
+        if isinstance(ast_node, Break):
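+            # Outside of a loop scope there is nothing to break out of, so fall back to 'return'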
+            if self.loop_scope:
+                self.print("break;")
+            else:
+                self.print("return;")
+
         # TODO: Why there are Decls for other types?
         if isinstance(ast_node, Decl):
             if isinstance(ast_node.elem, ArrayAccess):
diff --git a/src/pairs/code_gen/interface.py b/src/pairs/code_gen/interface.py
index b2631572552f28063c060accba929f3e27d7335f..9a67d659c1734c7b15bce67875023cb016546f28 100644
--- a/src/pairs/code_gen/interface.py
+++ b/src/pairs/code_gen/interface.py
@@ -3,7 +3,7 @@ from pairs.ir.functions import Call_Void, Call, Call_Int
 from pairs.ir.parameters import Parameter
 from pairs.ir.ret import Return
 from pairs.ir.scalars import ScalarOp
-from pairs.sim.domain import UpdateDomain
+from pairs.sim.domain import DomainRebalance, DomainUpdateLocal, DomainUpdateNeighborhood
 from pairs.sim.cell_lists import BuildCellListsStencil
 from pairs.sim.comm import Synchronize, Borders, Exchange, ReverseComm
 from pairs.ir.types import Types
@@ -27,17 +27,18 @@ class InterfaceModules:
 
     def create_all(self):
         self.initialize()
-        self.setup_cells()
-        self.update_domain()
-        self.update_cells(self.sim.reneighbor_frequency) 
-        self.communicate(self.sim.reneighbor_frequency)
-        self.reverse_comm() 
-        self.reset_volatiles()
+        self.setCellWidth()
+        self.setInteractionRadius()
+        self.updateDomain()
+        self.reneighbor()
+        self.refreshGhosts() 
+        self.reverseCommunicate() 
+        self.resetVolatiles()
 
         if self.sim._use_contact_history:
             if self.neighbor_lists:
-                self.build_contact_history(self.sim.reneighbor_frequency)
-            self.reset_contact_history()
+                self.buildContactHistory(self.sim.reneighbor_frequency)
+            self.resetContactHistory()
 
         self.rank()
         self.nlocal()
@@ -75,76 +76,85 @@ class InterfaceModules:
         self.sim.add_statement(inits)
 
     @pairs_interface_block
-    def setup_cells(self):
-        self.sim.module_name('setup_cells')
-        
+    def setCellWidth(self):
+        self.sim.module_name('setCellWidth')
         if self.sim.cell_lists.runtime_spacing:
             for d in range(self.sim.dims):
-                Assign(self.sim, self.sim.cell_lists.spacing[d], Parameter(self.sim, f'cell_spacing_d{d}', Types.Real))
+                Assign(self.sim, self.sim.cell_lists.spacing[d], Parameter(self.sim, f'cell_width_dim_{d}', Types.Real))
 
+    @pairs_interface_block
+    def setInteractionRadius(self):
+        self.sim.module_name('setInteractionRadius')
         if self.sim.cell_lists.runtime_cutoff_radius:
             Assign(self.sim, self.sim.cell_lists.cutoff_radius, Parameter(self.sim, 'cutoff_radius', Types.Real))
-
-        # This update assumes all particles have been created exactly in the rank that contains them 
-        self.sim.add_statement(UpdateDomain(self.sim))  
-        self.sim.add_statement(BuildCellListsStencil(self.sim, self.sim.cell_lists))
-        self.sim.add_statement(self.sim.update_cells_procedures)
-        
-    @pairs_interface_block
-    def update_domain(self):
-        self.sim.module_name('update_domain')
-        self.sim.add_statement(Exchange(self.sim._comm))    # Local particles must be contained in their owners before domain update
-        self.sim.add_statement(UpdateDomain(self.sim))
-        # Exchange is not needed after update since all locals are contained in thier owners
-        self.sim.add_statement(Borders(self.sim._comm))     # Ghosts must be recreated after update
-        self.sim.add_statement(ResetVolatileProperties(self.sim))   # Reset volatile includes the new locals
-        self.sim.add_statement(BuildCellListsStencil(self.sim, self.sim.cell_lists))    # Rebuild stencil since subdom sizes have changed
-        self.sim.add_statement(self.sim.update_cells_procedures)
         
     @pairs_interface_block
-    def reset_volatiles(self):
-        self.sim.module_name('reset_volatiles')
-        self.sim.add_statement(ResetVolatileProperties(self.sim))
-    
-    @pairs_interface_block
-    def update_cells(self, reneighbor_frequency=1):
-        self.sim.module_name('update_cells')
-        timestep = Parameter(self.sim, f'timestep', Types.Int32)
-        cond = ScalarOp.inline(ScalarOp.or_op(
-            ScalarOp.cmp((timestep + 1) % reneighbor_frequency, 0),
-            ScalarOp.cmp(timestep, 0)
-            ))
+    def updateDomain(self):
+        ''' This function needs to be called only once, after all particles have been created.
+        If rebalancing is enabled, the domain is rebalanced every time this function is called.
+        If rebalancing is disabled, calling this function has the same effect as calling 'reneighbor'.
+        '''
+        self.sim.module_name('updateDomain')
+
+        self.sim.add_statement(DomainUpdateNeighborhood(self.sim)) 
+
+        # Local particles must be contained in their owners before rebalancing, otherwise they may get lost
+        self.sim.add_statement(Exchange(self.sim._comm))
+
+        # Here AABBs assigned to each rank may change if rebalancing is enabled
+        self.sim.add_statement(DomainRebalance(self.sim))
+
+        # This is a cheap update to crop the subdom and find local non-empty AABBs
+        # Note: All local particles are strictly contained within AABBs. Therefore,
+        # no padding is needed to find non-empty AABBs
+        self.sim.add_statement(DomainUpdateLocal(self.sim))
+
+        # Rebuild stencil since subdom sizes have changed. It may also use non-empty AABBs to create halo cells
+        self.sim.add_statement(BuildCellListsStencil(self.sim, self.sim.cell_lists))
+
+        # Populate cells with local and ghost particles
+        self.sim.add_statement(self.sim.update_cells_procedures)
+
+        # Exchange is not needed since all locals are contained in their owners after deserialization,
+        # but ghosts must be recreated after rebalancing (optionally using the halo cells)
+        self.sim.add_statement(Borders(self.sim._comm))
+
+        # Resetting volatiles also covers the newly received locals
+        self.sim.add_statement(ResetVolatileProperties(self.sim))
 
-        self.sim.add_statement(Filter(self.sim, cond, self.sim.update_cells_procedures))
 
     @pairs_interface_block
-    def communicate(self, reneighbor_frequency=1):
-        self.sim.module_name('communicate')
-        timestep = Parameter(self.sim, f'timestep', Types.Int32)
-        cond = ScalarOp.inline(ScalarOp.or_op(
-            ScalarOp.cmp((timestep + 1) % reneighbor_frequency, 0),
-            ScalarOp.cmp(timestep, 0)
-            ))
-        
-        exchange = Filter(self.sim, cond, Exchange(self.sim._comm))
-        border_sync = Branch(self.sim, cond, 
-                             blk_if = Borders(self.sim._comm), 
-                             blk_else = Synchronize(self.sim._comm))
-        
-        self.sim.add_statement(exchange)
-        self.sim.add_statement(border_sync)
-        
-        # TODO: Maybe remove this from here, but volatiles must always be reset after exchange
-        self.sim.add_statement(Filter(self.sim, cond, Block(self.sim, ResetVolatileProperties(self.sim))))   
+    def reneighbor(self):
+        self.sim.module_name('reneighbor')
+        reneighboring_procedures = Block.from_list(self.sim, [
+            Exchange(self.sim._comm),
+            Borders(self.sim._comm),
+            # Note: DomainUpdateLocal must happen after exchange since local particles must be contained in AABBs
+            DomainUpdateLocal(self.sim),    
+            BuildCellListsStencil(self.sim, self.sim.cell_lists),
+            self.sim.update_cells_procedures,
+            ResetVolatileProperties(self.sim)
+        ])
+        self.sim.add_statement(reneighboring_procedures)
+
+    @pairs_interface_block
+    def refreshGhosts(self):
+        self.sim.module_name('refreshGhosts')
+        self.sim.add_statement(Synchronize(self.sim._comm))
 
     @pairs_interface_block
-    def reverse_comm(self):
-        self.sim.module_name('reverse_comm')
+    def reverseCommunicate(self):
+        self.sim.module_name('reverseCommunicate')
         self.sim.add_statement(ReverseComm(self.sim._comm, reduce=True))
     
     @pairs_interface_block
-    def build_contact_history(self, reneighbor_frequency=1):
-        self.sim.module_name('build_contact_history')
+    def resetVolatiles(self):
+        self.sim.module_name('resetVolatiles')
+        self.sim.add_statement(ResetVolatileProperties(self.sim))
+    
+    @pairs_interface_block
+    def buildContactHistory(self, reneighbor_frequency=1):
+        self.sim.module_name('buildContactHistory')
         timestep = Parameter(self.sim, f'timestep', Types.Int32)
         cond = ScalarOp.inline(ScalarOp.or_op(
             ScalarOp.cmp((timestep + 1) % reneighbor_frequency, 0),
@@ -156,8 +166,8 @@ class InterfaceModules:
                    BuildContactHistory(self.sim, self.sim._contact_history, self.sim.cell_lists)))
 
     @pairs_interface_block
-    def reset_contact_history(self):
-        self.sim.module_name('reset_contact_history')
+    def resetContactHistory(self):
+        self.sim.module_name('resetContactHistory')
         self.sim.add_statement(ResetContactHistoryUsageStatus(self.sim, self.sim._contact_history))
         self.sim.add_statement(ClearUnusedContactHistory(self.sim, self.sim._contact_history))
 
diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py
index 7d90c3c97e3fac7d4cb7cdf48b9d8b1d40720808..78e7a6685c1cc6903daf4e696c2f94ecc9097f46 100644
--- a/src/pairs/ir/loops.py
+++ b/src/pairs/ir/loops.py
@@ -133,3 +133,13 @@ class Continue(ASTNode):
 
     def __call__(self):
         self.sim.add_statement(self)
+
+class Break(ASTNode):
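+    # Terminates the innermost enclosing loop; lowered to a 'break' statement (or 'return' outside a loop) by CGen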
+    def __init__(self, sim):
+        super().__init__(sim)
+
+    def __str__(self):
+        return f"Break<>"
+
+    def __call__(self):
+        self.sim.add_statement(self)
\ No newline at end of file
diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py
index d4e571ea2891ab9f388e659ad2886bebb4fb4026..1043d39e0b552ffda6fcb08de58c8c34392efec2 100644
--- a/src/pairs/sim/cell_lists.py
+++ b/src/pairs/sim/cell_lists.py
@@ -6,7 +6,7 @@ from pairs.ir.atomic import AtomicAdd
 from pairs.ir.block import pairs_device_block, pairs_host_block
 from pairs.ir.branches import Branch, Filter
 from pairs.ir.cast import Cast
-from pairs.ir.loops import For, ParticleFor, While
+from pairs.ir.loops import For, ParticleFor, While, Break
 from pairs.ir.math import Ceil
 from pairs.ir.scalars import ScalarOp
 from pairs.ir.select import Select
@@ -108,7 +108,6 @@ class BuildCellListsStencil(Lowerable):
                         Assign(self.sim, nstencil, nstencil + 1)
 
         # Halo cell generation
-        # TODO: Defer halo cells generation to dom_part
         # ----------------------------------------------------
         if self.sim._use_halo_cells:
             halo_ncells_capacity = self.cell_lists.halo_ncells_capacity
@@ -117,26 +116,21 @@ class BuildCellListsStencil(Lowerable):
             halo_cells = self.cell_lists.halo_cells
             Assign(self.sim, n, 1)
 
-            # Note: We add +2 to each layer since it's possible that the outermost local layer of  
-            # master doesn't fully overlap with innermost ghost layer of neighbor, and vice versa.
-            layers_0 = self.sim.add_temp_var(0) 
-            Assign(self.sim, layers_0, Ceil(self.sim, (cutoff_radius / spacing[0])) + 2)
-            layers_1 = self.sim.add_temp_var(0)
-            Assign(self.sim, layers_1, Ceil(self.sim, (cutoff_radius / spacing[1])) + 2)
-            layers_2 = self.sim.add_temp_var(0)
-            Assign(self.sim, layers_2, Ceil(self.sim, (cutoff_radius / spacing[2])) + 2)
+            # Note: We add one layer to each side of the border since it's possible that the outermost local 
+            # layer of master doesn't fully overlap with innermost ghost layer of neighbor, and vice versa.
+            layers = [self.sim.add_temp_var(0) for _ in range(self.sim.ndims())]
+            for d in range(self.sim.ndims()):
+                Assign(self.sim, layers[d], Ceil(self.sim, (cutoff_radius / spacing[d])))
         
             for x in For(self.sim, 0, dim_ncells[0]):
                 for y in For(self.sim, 0, dim_ncells[1]):
                     for z in For(self.sim, 0, dim_ncells[2]):
-                        cond0 = ScalarOp.or_op(x<layers_0, x>=(dim_ncells[0]-layers_0)) 
-                        cond1 = ScalarOp.or_op(y<layers_1, y>=(dim_ncells[1]-layers_1)) 
-                        cond2 = ScalarOp.or_op(z<layers_2, z>=(dim_ncells[2]-layers_2)) 
-                        fullcond = ScalarOp.or_op(ScalarOp.or_op(cond0, cond1), cond2) 
-                        for _ in Filter(self.sim, fullcond):
-                            index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
-                            Assign(self.sim, halo_cells[n], index + 1)
-                            Assign(self.sim, n, n+1)
+                        for is_halo in self.cell_lists.dom_part.halo_condition(x, y, z, spacing, layers):
+                            for _ in Filter(self.sim, is_halo):
+                                index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
+                                Assign(self.sim, halo_cells[n], index + 1)
+                                Assign(self.sim, n, n+1)
+                                Break(self.sim)()   # Go to next cell
 
 
 class BuildCellLists(Lowerable):
@@ -197,6 +191,8 @@ class PartitionCellLists(Lowerable):
         cell_particles = self.cell_lists.cell_particles
         shapes_buffer = self.cell_lists.shapes_buffer
 
+        # for p in ParticleFor(self.sim, local_only=False):
+        #     cell = particle_cell[p]
         for cell in For(self.sim, 0, self.cell_lists.ncells):
             start = self.sim.add_temp_var(0)
             end = self.sim.add_temp_var(0)
diff --git a/src/pairs/sim/domain.py b/src/pairs/sim/domain.py
index dfa73de79e07c2deb4476c3ae4e4f8902c685b63..d10081fafb62debfc5b206af134682ef538f9471 100644
--- a/src/pairs/sim/domain.py
+++ b/src/pairs/sim/domain.py
@@ -1,10 +1,28 @@
 from pairs.ir.block import pairs_inline
 from pairs.sim.lowerable import Lowerable
 
-class UpdateDomain(Lowerable):
+class DomainUpdateLocal(Lowerable):
     def __init__(self, sim):
         super().__init__(sim)
 
     @pairs_inline
     def lower(self):
-        self.sim.domain_partitioning().update()
+        self.sim.domain_partitioning().update_local()
+
+
+class DomainUpdateNeighborhood(Lowerable):
+    def __init__(self, sim):
+        super().__init__(sim)
+
+    @pairs_inline
+    def lower(self):
+        self.sim.domain_partitioning().update_neighborhood()
+
+
+class DomainRebalance(Lowerable):
+    def __init__(self, sim):
+        super().__init__(sim)
+
+    @pairs_inline
+    def lower(self):
+        self.sim.domain_partitioning().rebalance()
\ No newline at end of file
diff --git a/src/pairs/sim/domain_partitioning.py b/src/pairs/sim/domain_partitioning.py
index 0229e19ddedf8dd5e6024acafdbc9de2be707d63..ca318ece69848c04ef864ad34649c393e208d247 100644
--- a/src/pairs/sim/domain_partitioning.py
+++ b/src/pairs/sim/domain_partitioning.py
@@ -12,6 +12,8 @@ from pairs.ir.device import CopyArray
 from pairs.ir.contexts import Contexts
 from pairs.ir.actions import Actions
 from pairs.ir.print import Print
+from pairs.ir.cast import Cast
+from pairs.ir.math import Abs
 
 class DimensionRanges:
     def __init__(self, sim):
@@ -43,10 +45,12 @@ class DimensionRanges:
         return sum([array[i] for i in range(total_size)])
 
     def reduce_sum_step_indexes(self, step, array):
-       return sum([array[i] for i in self.step_indexes(step)])
+        return sum([array[i] for i in self.step_indexes(step)])
 
-    def update(self):
-        Call_Void(self.sim, "pairs_runtime->updateDomain", [])
+    def halo_condition(self, x, y, z, spacing, layers):
+        raise Exception("Regular6DStencil does not support halo cells yet.")
+    
+    def update_neighborhood(self):
         Assign(self.sim, self.rank, Call_Int(self.sim, "pairs_runtime->getDomainPartitioner()->getRank", []))
 
         Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['neighbor_ranks', self.neighbor_ranks, self.sim.ndims() * 2])
@@ -58,6 +62,12 @@ class DimensionRanges:
                 Assign(self.sim, self.sim.grid.min(d), Call(self.sim, "pairs_runtime->getDomainPartitioner()->getMin", [d], Types.Real))
                 Assign(self.sim, self.sim.grid.max(d), Call(self.sim, "pairs_runtime->getDomainPartitioner()->getMax", [d], Types.Real))
 
+    def rebalance(self):
+        pass
+
+    def update_local(self):
+        pass
+
     def ghost_particles(self, step, position, offset=0.0):
         # Particles with one of the following flags are ignored
         flags_to_exclude = (Flags.Infinite | Flags.Global)
@@ -154,20 +164,25 @@ class DimensionRanges:
 
 class BlockForest:
     def __init__(self, sim):
-        self.sim                = sim
-        self.reduce_step        = sim.add_var('reduce_step', Types.Int32)   # this var is treated as a tmp (workaround for gpu)
+        self.sim                    = sim
+        self.reduce_step            = sim.add_var('reduce_step', Types.Int32)   # this var is treated as a tmp (workaround for gpu)
         self.reduce_step.force_read = True
-        self.rank               = sim.add_var('rank', Types.Int32)
-        self.nranks             = sim.add_var('nranks', Types.Int32)
-        self.nranks_capacity    = sim.add_var('nranks_capacity', Types.Int32, init_value=27)
-        self.ntotal_aabbs       = sim.add_var('ntotal_aabbs', Types.Int32)
-        self.aabb_capacity      = sim.add_var('aabb_capacity', Types.Int32, init_value=27)
-        self.ranks              = sim.add_array('ranks', [self.nranks_capacity], Types.Int32)
-        self.naabbs             = sim.add_array('naabbs', [self.nranks_capacity], Types.Int32)
-        self.aabb_offsets       = sim.add_array('aabb_offsets', [self.nranks_capacity], Types.Int32)
-        self.aabbs              = sim.add_array('aabbs', [self.aabb_capacity, 6], Types.Real)
-        self.subdom             = sim.add_array('subdom', [sim.ndims() * 2], Types.Real)
-
+        self.rank                   = sim.add_var('rank', Types.Int32)
+        self.nranks                 = sim.add_var('nranks', Types.Int32)
+        self.nranks_capacity        = sim.add_var('nranks_capacity', Types.Int32, init_value=27)
+        self.total_num_neigh_aabbs  = sim.add_var('total_num_neigh_aabbs', Types.Int32)
+        self.num_local_aabbs        = sim.add_var('num_local_aabbs', Types.Int32)
+        self.neigh_aabb_capacity    = sim.add_var('neigh_aabb_capacity', Types.Int32, init_value=27)
+        self.local_aabb_capacity    = sim.add_var('local_aabb_capacity', Types.Int32, init_value=1)
+        self.ranks                  = sim.add_array('ranks', [self.nranks_capacity], Types.Int32)
+        self.num_neigh_aabbs        = sim.add_array('num_neigh_aabbs', [self.nranks_capacity], Types.Int32)
+        self.aabb_offsets           = sim.add_array('aabb_offsets', [self.nranks_capacity], Types.Int32)
+        self.neigh_aabbs            = sim.add_array('neigh_aabbs', [self.neigh_aabb_capacity, 6], Types.Real)
+        self.local_aabbs            = sim.add_array('local_aabbs', [self.local_aabb_capacity, 6], Types.Real)
+        self.non_empty_local_aabbs  = sim.add_array('non_empty_local_aabbs', [self.local_aabb_capacity], Types.Int32)
+        self.subdom                 = sim.add_array('subdom', [sim.ndims() * 2], Types.Real)
+        self.has_non_empty_aabb_in_neighborhood_of_rank = sim.add_array('has_non_empty_aabb_in_neighborhood_of_rank', [self.nranks_capacity], Types.Int32)
+    
     def min(self, dim):
         return self.subdom[dim * 2 + 0]
 
@@ -192,12 +207,129 @@ class BlockForest:
             Assign(self.sim, self.reduce_step, ScalarOp.inline( self.reduce_step + array[i]))
             
         return self.reduce_step
+    
+    def halo_condition(self, x, y, z, spacing, layers):
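+        # For each non-empty local AABB, yield a condition that marks cell (x, y, z) as a halo cell of that AABB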
+        # [Experimental option] Can reduce the number of halo cells generated, but comes with overhead
+        optimize_paddings = self.sim._optimize_halo_paddings
+
+        for aabb_id in For(self.sim, 0, self.num_local_aabbs):
+            for _ in Filter(self.sim, self.non_empty_local_aabbs[aabb_id]):
+                aabb = [self.local_aabbs[aabb_id][i] for i in range(6)]
+
+                # Index meaning: halo_[dim][dim_min/dim_max][min/max for each padding]
+                if optimize_paddings:
+                    tol = 1e-6
+
+                    # Padding X
+                    #---------------------------------------------------------------------
+                    float_00 = (aabb[0] - (self.min(0) - spacing[0])) / spacing[0]
+                    int_00 = Cast.int(self.sim, float_00)
+                    sub_00 = tol < Abs(self.sim, int_00 - float_00 + layers[0])
+                    add_00 = tol < Abs(self.sim, int_00 - float_00)
+                    halo_000 = Select(self.sim, sub_00, int_00 - layers[0], int_00)
+                    halo_001 = Select(self.sim, add_00, int_00 + layers[0], int_00)
+
+
+                    float_01 = (aabb[1] - (self.min(0) - spacing[0])) / spacing[0]
+                    int_01 = Cast.int(self.sim, float_01)
+                    sub_01 = tol < Abs(self.sim, int_01 - float_01 + layers[0])
+                    add_01 = tol < Abs(self.sim, int_01 - float_01)
+                    halo_010 = Select(self.sim, sub_01, int_01 - layers[0], int_01)
+                    halo_011 = Select(self.sim, add_01, int_01 + layers[0], int_01)
+
+                    # Padding Y
+                    #---------------------------------------------------------------------
+                    float_10 = (aabb[2] - (self.min(1) - spacing[1])) / spacing[1]
+                    int_10 = Cast.int(self.sim, float_10)
+                    sub_10 = tol < Abs(self.sim, int_10 - float_10 + layers[1])
+                    add_10 = tol < Abs(self.sim, int_10 - float_10)
+                    halo_100 = Select(self.sim, sub_10, int_10 - layers[1], int_10)
+                    halo_101 = Select(self.sim, add_10, int_10 + layers[1], int_10)
+
+                    float_11 = (aabb[3] - (self.min(1) - spacing[1])) / spacing[1]
+                    int_11 = Cast.int(self.sim, float_11)
+                    sub_11 = tol < Abs(self.sim, int_11 - float_11 + layers[1])
+                    add_11 = tol < Abs(self.sim, int_11 - float_11)
+                    halo_110 = Select(self.sim, sub_11, int_11 - layers[1], int_11)
+                    halo_111 = Select(self.sim, add_11, int_11 + layers[1], int_11)
+
+                    # Padding Z
+                    #---------------------------------------------------------------------
+                    float_20 = (aabb[4] - (self.min(2) - spacing[2])) / spacing[2]
+                    int_20 = Cast.int(self.sim, float_20)
+                    sub_20 = tol < Abs(self.sim, int_20 - float_20 + layers[2])
+                    add_20 = tol < Abs(self.sim, int_20 - float_20)
+                    halo_200 = Select(self.sim, sub_20, int_20 - layers[2], int_20)
+                    halo_201 = Select(self.sim, add_20, int_20 + layers[2], int_20)
+
+                    float_21 = (aabb[5] - (self.min(2) - spacing[2])) / spacing[2]
+                    int_21 = Cast.int(self.sim, float_21)
+                    sub_21 = tol < Abs(self.sim, int_21 - float_21 + layers[2])
+                    add_21 = tol < Abs(self.sim, int_21 - float_21)
+                    halo_210 = Select(self.sim, sub_21, int_21 - layers[2], int_21)
+                    halo_211 = Select(self.sim, add_21, int_21 + layers[2], int_21)
+
+                else:
+                    halo_000 = Cast.int(self.sim, (aabb[0] - (self.min(0) - spacing[0])) / spacing[0]) - layers[0]
+                    halo_001 = Cast.int(self.sim, (aabb[0] - (self.min(0) - spacing[0])) / spacing[0]) + layers[0]
+                    halo_010 = Cast.int(self.sim, (aabb[1] - (self.min(0) - spacing[0])) / spacing[0]) - layers[0]
+                    halo_011 = Cast.int(self.sim, (aabb[1] - (self.min(0) - spacing[0])) / spacing[0]) + layers[0]
+                    
+                    halo_100 = Cast.int(self.sim, (aabb[2] - (self.min(1) - spacing[1])) / spacing[1]) - layers[1]
+                    halo_101 = Cast.int(self.sim, (aabb[2] - (self.min(1) - spacing[1])) / spacing[1]) + layers[1]
+                    halo_110 = Cast.int(self.sim, (aabb[3] - (self.min(1) - spacing[1])) / spacing[1]) - layers[1]
+                    halo_111 = Cast.int(self.sim, (aabb[3] - (self.min(1) - spacing[1])) / spacing[1]) + layers[1]
+                    
+                    halo_200 = Cast.int(self.sim, (aabb[4] - (self.min(2) - spacing[2])) / spacing[2]) - layers[2]
+                    halo_201 = Cast.int(self.sim, (aabb[4] - (self.min(2) - spacing[2])) / spacing[2]) + layers[2]
+                    halo_210 = Cast.int(self.sim, (aabb[5] - (self.min(2) - spacing[2])) / spacing[2]) - layers[2]
+                    halo_211 = Cast.int(self.sim, (aabb[5] - (self.min(2) - spacing[2])) / spacing[2]) + layers[2]
+
+                c0 = ScalarOp.and_op(x >= halo_000, x <= halo_011)
+                c1 = ScalarOp.and_op(y >= halo_100, y <= halo_111)
+                c2 = ScalarOp.and_op(z >= halo_200, z <= halo_211)
+                
+                cell_is_within_padded_aabb = ScalarOp.and_op(ScalarOp.and_op(c0, c1), c2)
+
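+                # Halo cells are cells inside the padded AABB that lie within 'layers' cells of at least one of its faces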
+                for _ in Filter(self.sim, cell_is_within_padded_aabb):
+                    cond_0 = ScalarOp.or_op(x >= halo_010, x <= halo_001)
+                    cond_1 = ScalarOp.or_op(y >= halo_110, y <= halo_101)
+                    cond_2 = ScalarOp.or_op(z >= halo_210, z <= halo_201)
+                    yield ScalarOp.or_op(ScalarOp.or_op(cond_0, cond_1), cond_2) 
+
+    def update_local(self):
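+        # Pull the cropped subdomain and the local AABB information (including non-empty flags) from the runtime into the generated arrays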
+        Call_Void(self.sim, "pairs_runtime->getDomainPartitioner()->updateLocal", [])
+        Assign(self.sim, self.num_local_aabbs, Call_Int(self.sim, "pairs_runtime->getDomainPartitioner()->getNumberOfLocalAABBs", []))
+        
+        for _ in Filter(self.sim, self.local_aabb_capacity < self.num_local_aabbs):
+            Assign(self.sim, self.local_aabb_capacity, self.num_local_aabbs + 4)
+            for arr in self.local_aabb_capacity.bonded_arrays():
+                arr.realloc()
+
+        for _ in Filter(self.sim, ScalarOp.neq(self.nranks, 0)):
+            if self.sim._target.is_gpu():
+                CopyArray(self.sim, self.local_aabbs, Contexts.Host, Actions.WriteOnly, self.num_local_aabbs * 6)
+                CopyArray(self.sim, self.non_empty_local_aabbs, Contexts.Host, Actions.WriteOnly, self.num_local_aabbs)
+                CopyArray(self.sim, self.has_non_empty_aabb_in_neighborhood_of_rank, Contexts.Host, Actions.WriteOnly, self.nranks)
 
-    def update(self):
-        Call_Void(self.sim, "pairs_runtime->updateDomain", [])
+            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['local_aabbs', self.local_aabbs, self.num_local_aabbs * 6])
+            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['non_empty_local_aabbs', self.non_empty_local_aabbs, self.num_local_aabbs])
+            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['has_non_empty_aabb_in_neighborhood_of_rank', self.has_non_empty_aabb_in_neighborhood_of_rank, self.nranks])
+
+        if self.sim._target.is_gpu():
+            CopyArray(self.sim, self.subdom, Contexts.Host, Actions.WriteOnly)
+
+        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['subdom', self.subdom, self.sim.ndims() * 2])
+        if isinstance(self.sim.grid, MutableGrid):
+            for d in range(self.sim.dims):
+                Assign(self.sim, self.sim.grid.min(d), Call(self.sim, "pairs_runtime->getDomainPartitioner()->getMin", [d], Types.Real))
+                Assign(self.sim, self.sim.grid.max(d), Call(self.sim, "pairs_runtime->getDomainPartitioner()->getMax", [d], Types.Real))
+
+    def update_neighborhood(self):
+        Call_Void(self.sim, "pairs_runtime->getDomainPartitioner()->updateNeighborhood", [])
         Assign(self.sim, self.rank, Call_Int(self.sim, "pairs_runtime->getDomainPartitioner()->getRank", []))
         Assign(self.sim, self.nranks, Call_Int(self.sim, "pairs_runtime->getNumberOfNeighborRanks", []))
-        Assign(self.sim, self.ntotal_aabbs, Call_Int(self.sim, "pairs_runtime->getNumberOfNeighborAABBs", []))
+        Assign(self.sim, self.total_num_neigh_aabbs, Call_Int(self.sim, "pairs_runtime->getNumberOfNeighborAABBs", []))
 
         for _ in Filter(self.sim, ScalarOp.neq(self.nranks, 0)):
             for _ in Filter(self.sim, self.nranks_capacity < self.nranks):
@@ -205,35 +337,28 @@ class BlockForest:
                 for arr in self.nranks_capacity.bonded_arrays():
                     arr.realloc()
 
-            for _ in Filter(self.sim, self.aabb_capacity < self.ntotal_aabbs):
-                Assign(self.sim, self.aabb_capacity, self.ntotal_aabbs + 20)
-                for arr in self.aabb_capacity.bonded_arrays():
+            for _ in Filter(self.sim, self.neigh_aabb_capacity < self.total_num_neigh_aabbs):
+                Assign(self.sim, self.neigh_aabb_capacity, self.total_num_neigh_aabbs + 20)
+                for arr in self.neigh_aabb_capacity.bonded_arrays():
                     arr.realloc()
 
             if self.sim._target.is_gpu():
                 CopyArray(self.sim, self.ranks, Contexts.Host, Actions.WriteOnly, self.nranks)
-                CopyArray(self.sim, self.naabbs, Contexts.Host, Actions.WriteOnly, self.nranks)
+                CopyArray(self.sim, self.num_neigh_aabbs, Contexts.Host, Actions.WriteOnly, self.nranks)
                 CopyArray(self.sim, self.aabb_offsets, Contexts.Host, Actions.WriteOnly, self.nranks)
-                CopyArray(self.sim, self.aabbs, Contexts.Host, Actions.WriteOnly, self.ntotal_aabbs * 6)
+                CopyArray(self.sim, self.neigh_aabbs, Contexts.Host, Actions.WriteOnly, self.total_num_neigh_aabbs * 6)
 
             Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['ranks', self.ranks, self.nranks])
-            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['naabbs', self.naabbs, self.nranks])
+            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['num_neigh_aabbs', self.num_neigh_aabbs, self.nranks])
             Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['aabb_offsets', self.aabb_offsets, self.nranks])
-            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['aabbs', self.aabbs, self.ntotal_aabbs * 6])
-        
-        if self.sim._target.is_gpu():
-            CopyArray(self.sim, self.subdom, Contexts.Host, Actions.WriteOnly)
-
-        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['subdom', self.subdom, self.sim.ndims() * 2])
-        
-        if isinstance(self.sim.grid, MutableGrid):
-            for d in range(self.sim.dims):
-                Assign(self.sim, self.sim.grid.min(d), Call(self.sim, "pairs_runtime->getDomainPartitioner()->getMin", [d], Types.Real))
-                Assign(self.sim, self.sim.grid.max(d), Call(self.sim, "pairs_runtime->getDomainPartitioner()->getMax", [d], Types.Real))
+            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['neigh_aabbs', self.neigh_aabbs, self.total_num_neigh_aabbs * 6])
+            
+    def rebalance(self):
+        Call_Void(self.sim, "pairs_runtime->getDomainPartitioner()->rebalance", [])
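+        # Block ownership may have changed, so neighbor ranks and their AABBs must be re-fetched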
+        self.update_neighborhood()
 
     def ghost_particles(self, step, position, offset=0.0):
         if self.sim._use_halo_cells:
-            # No support for adaptive blocks yet
             yield from self.ghost_particles_halo_cells(step, position, offset)
         else:
             yield from self.ghost_particles_original(step, position, offset)
@@ -248,87 +373,19 @@ class BlockForest:
         flags_to_exclude = (Flags.Infinite | Flags.Global)
 
         for r in self.step_indexes(0):     # for every neighbor rank
-            for i in For(self.sim, 0, self.sim.nlocal):     # for every local particle in this rank
-                particle_flags = self.sim.particle_flags
-
-                for _ in Filter(self.sim, ScalarOp.cmp(particle_flags[i] & flags_to_exclude, 0)):
-                    for aabb_id in For(self.sim, self.aabb_offsets[r], self.aabb_offsets[r] + self.naabbs[r]): # for every aabb of this neighbor
-                        for _ in Filter(self.sim, ScalarOp.neq(self.ranks[r] , self.rank)):     # if my neighobr is not my own rank
-                            full_cond = None
-                            pbc_shifts = []
-
-                            for d in range(self.sim.ndims()):
-                                aabb_min = self.aabbs[aabb_id][d * 2 + 0]
-                                aabb_max = self.aabbs[aabb_id][d * 2 + 1]
-                                d_pbc = 0
-                                d_length = self.sim.grid.length(d)
-
-                                if self.sim._pbc[d]:
-                                    center = aabb_min + (aabb_max - aabb_min) * 0.5     # center of neighbor block
-                                    dist = position[i][d] - center                      # distance of our particle from center of neighbor
-                                    cond_pbc_neg = dist >  (d_length * 0.5)
-                                    cond_pbc_pos = dist < -(d_length * 0.5)
-
-                                    d_pbc = Select(self.sim, cond_pbc_neg, -1, Select(self.sim, cond_pbc_pos, 1, 0))
-
-                                adj_pos = position[i][d] + d_pbc * d_length 
-                                d_cond = ScalarOp.and_op(adj_pos > aabb_min - offset, adj_pos < aabb_max + offset)
-                                full_cond = d_cond if full_cond is None else ScalarOp.and_op(full_cond, d_cond)
-                                pbc_shifts.append(d_pbc)
-
-                            for _ in Filter(self.sim, full_cond):
-                                yield i, r, self.ranks[r], pbc_shifts
-
-                        for _ in Filter(self.sim, ScalarOp.cmp(self.ranks[r] , self.rank)):     # if my neighbor is me
-                            pbc_shifts = []
-                            isghost = Lit(self.sim, 0)
-
-                            for d in range(self.sim.ndims()):
-                                aabb_min = self.aabbs[aabb_id][d * 2 + 0]
-                                aabb_max = self.aabbs[aabb_id][d * 2 + 1]
-                                center = aabb_min + (aabb_max - aabb_min) * 0.5     # center of neighbor block
-                                dist = position[i][d] - center                      # distance of our particle from center of neighbor
-                                d_pbc = 0
-                                d_length = self.sim.grid.length(d)
-
-                                if self.sim._pbc[d]:
-                                    cond_pbc_neg = dist >  (d_length*0.5 - offset)
-                                    cond_pbc_pos = dist < -(d_length*0.5 - offset)
-                                    d_pbc = Select(self.sim, cond_pbc_neg, -1, Select(self.sim, cond_pbc_pos, 1, 0))
-                                    isghost = ScalarOp.or_op(isghost, d_pbc)
-
-                                pbc_shifts.append(d_pbc)
-                            
-                            for _ in Filter(self.sim, isghost):
-                                yield i, r, self.ranks[r], pbc_shifts
-
-
-    def ghost_particles_halo_cells(self, step, position, offset=0.0):
-        # Particles with one of the following flags are ignored
-        flags_to_exclude = (Flags.Infinite | Flags.Global)
-        cells_to_check = self.sim.cell_lists.halo_cells
-        ncells_to_check = self.sim.cell_lists.halo_ncells
-        for r in self.step_indexes(0):     # for every neighbor rank
-            for nc in For(self.sim, 0, ncells_to_check):
-                c = self.sim.add_temp_var(0)
-                Assign(self.sim, c, cells_to_check[nc])
-                for p in For(self.sim, 0, self.sim.cell_lists.cell_sizes[c]):
-                    i = self.sim.cell_lists.cell_particles[c][p]
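+            # Only consider neighbor ranks for which this rank has a non-empty AABB in their neighborhood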
+            for _ in Filter(self.sim, self.has_non_empty_aabb_in_neighborhood_of_rank[r]):
+                for i in For(self.sim, 0, self.sim.nlocal):     # for every local particle in this rank
                     particle_flags = self.sim.particle_flags
 
-                    # Skip ghost particles
-                    for _ in Filter(self.sim, i >= self.sim.nlocal):
-                        Continue(self.sim)()
-
                     for _ in Filter(self.sim, ScalarOp.cmp(particle_flags[i] & flags_to_exclude, 0)):
-                        for aabb_id in For(self.sim, self.aabb_offsets[r], self.aabb_offsets[r] + self.naabbs[r]): # for every aabb of this neighbor
+                        for aabb_id in For(self.sim, self.aabb_offsets[r], self.aabb_offsets[r] + self.num_neigh_aabbs[r]): # for every aabb of this neighbor
                             for _ in Filter(self.sim, ScalarOp.neq(self.ranks[r] , self.rank)):     # if my neighobr is not my own rank
                                 full_cond = None
                                 pbc_shifts = []
 
                                 for d in range(self.sim.ndims()):
-                                    aabb_min = self.aabbs[aabb_id][d * 2 + 0]
-                                    aabb_max = self.aabbs[aabb_id][d * 2 + 1]
+                                    aabb_min = self.neigh_aabbs[aabb_id][d * 2 + 0]
+                                    aabb_max = self.neigh_aabbs[aabb_id][d * 2 + 1]
                                     d_pbc = 0
                                     d_length = self.sim.grid.length(d)
 
@@ -353,8 +410,8 @@ class BlockForest:
                                 isghost = Lit(self.sim, 0)
 
                                 for d in range(self.sim.ndims()):
-                                    aabb_min = self.aabbs[aabb_id][d * 2 + 0]
-                                    aabb_max = self.aabbs[aabb_id][d * 2 + 1]
+                                    aabb_min = self.neigh_aabbs[aabb_id][d * 2 + 0]
+                                    aabb_max = self.neigh_aabbs[aabb_id][d * 2 + 1]
                                     center = aabb_min + (aabb_max - aabb_min) * 0.5     # center of neighbor block
                                     dist = position[i][d] - center                      # distance of our particle from center of neighbor
                                     d_pbc = 0
@@ -370,3 +427,73 @@ class BlockForest:
                                 
                                 for _ in Filter(self.sim, isghost):
                                     yield i, r, self.ranks[r], pbc_shifts
+
+
+    def ghost_particles_halo_cells(self, step, position, offset=0.0):
+        # Particles with one of the following flags are ignored
+        flags_to_exclude = (Flags.Infinite | Flags.Global)
+        cells_to_check = self.sim.cell_lists.halo_cells
+        ncells_to_check = self.sim.cell_lists.halo_ncells
+        for r in self.step_indexes(0):     # for every neighbor rank
+            for _ in Filter(self.sim, self.has_non_empty_aabb_in_neighborhood_of_rank[r]):
+                for nc in For(self.sim, 0, ncells_to_check):
+                    c = self.sim.add_temp_var(0)
+                    Assign(self.sim, c, cells_to_check[nc])
+                    for p in For(self.sim, 0, self.sim.cell_lists.cell_sizes[c]):
+                        i = self.sim.cell_lists.cell_particles[c][p]
+                        particle_flags = self.sim.particle_flags
+
+                        # Skip ghost particles
+                        for _ in Filter(self.sim, i >= self.sim.nlocal):
+                            Continue(self.sim)()
+
+                        for _ in Filter(self.sim, ScalarOp.cmp(particle_flags[i] & flags_to_exclude, 0)):
+                            for aabb_id in For(self.sim, self.aabb_offsets[r], self.aabb_offsets[r] + self.num_neigh_aabbs[r]): # for every aabb of this neighbor
+                                for _ in Filter(self.sim, ScalarOp.neq(self.ranks[r] , self.rank)):     # if my neighbor is not my own rank
+                                    full_cond = None
+                                    pbc_shifts = []
+
+                                    for d in range(self.sim.ndims()):
+                                        aabb_min = self.neigh_aabbs[aabb_id][d * 2 + 0]
+                                        aabb_max = self.neigh_aabbs[aabb_id][d * 2 + 1]
+                                        d_pbc = 0
+                                        d_length = self.sim.grid.length(d)
+
+                                        if self.sim._pbc[d]:
+                                            center = aabb_min + (aabb_max - aabb_min) * 0.5     # center of neighbor block
+                                            dist = position[i][d] - center                      # distance of our particle from center of neighbor
+                                            cond_pbc_neg = dist >  (d_length * 0.5)
+                                            cond_pbc_pos = dist < -(d_length * 0.5)
+
+                                            d_pbc = Select(self.sim, cond_pbc_neg, -1, Select(self.sim, cond_pbc_pos, 1, 0))
+
+                                        adj_pos = position[i][d] + d_pbc * d_length 
+                                        d_cond = ScalarOp.and_op(adj_pos > aabb_min - offset, adj_pos < aabb_max + offset)
+                                        full_cond = d_cond if full_cond is None else ScalarOp.and_op(full_cond, d_cond)
+                                        pbc_shifts.append(d_pbc)
+
+                                    for _ in Filter(self.sim, full_cond):
+                                        yield i, r, self.ranks[r], pbc_shifts
+
+                                for _ in Filter(self.sim, ScalarOp.cmp(self.ranks[r] , self.rank)):     # if my neighbor is me
+                                    pbc_shifts = []
+                                    isghost = Lit(self.sim, 0)
+
+                                    for d in range(self.sim.ndims()):
+                                        aabb_min = self.neigh_aabbs[aabb_id][d * 2 + 0]
+                                        aabb_max = self.neigh_aabbs[aabb_id][d * 2 + 1]
+                                        center = aabb_min + (aabb_max - aabb_min) * 0.5     # center of neighbor block
+                                        dist = position[i][d] - center                      # distance of our particle from center of neighbor
+                                        d_pbc = 0
+                                        d_length = self.sim.grid.length(d)
+
+                                        if self.sim._pbc[d]:
+                                            cond_pbc_neg = dist >  (d_length*0.5 - offset)
+                                            cond_pbc_pos = dist < -(d_length*0.5 - offset)
+                                            d_pbc = Select(self.sim, cond_pbc_neg, -1, Select(self.sim, cond_pbc_pos, 1, 0))
+                                            isghost = ScalarOp.or_op(isghost, d_pbc)
+
+                                        pbc_shifts.append(d_pbc)
+                                    
+                                    for _ in Filter(self.sim, isghost):
+                                        yield i, r, self.ranks[r], pbc_shifts
diff --git a/src/pairs/sim/global_interaction.py b/src/pairs/sim/global_interaction.py
index dca2a775ee301def37a7d334fd9232d878aac42b..3bb6594f48f4ff38180f9fba61563b884177d9d1 100644
--- a/src/pairs/sim/global_interaction.py
+++ b/src/pairs/sim/global_interaction.py
@@ -28,20 +28,19 @@ class GlobalLocalInteraction(ParticleInteraction):
             Assign(self.sim, first_cell_bytes, self.cell_lists.cell_capacity * Sizeof(self.sim, Types.Int32))
             CopyArray(self.sim, self.cell_lists.cell_sizes, Contexts.Host, Actions.ReadOnly, first_cell_bytes)
         
-        for ishape in range(self.maxs):
+        for ishape in range(self.maxs): # shape of globals
             if self.include_shape(ishape):
-                # Loop over the global cell
-                for p in For(self.sim, 0, self.cell_lists.cell_sizes[0]):
-                    i = self.cell_lists.cell_particles[0][p]
-                    # TODO: Skip if the bounding box of the global body doesn't intersect the subdom of this rank
-                    for _ in Filter(self.sim, ScalarOp.and_op(
-                        ScalarOp.cmp(self.sim.particle_shape[i], self.sim.get_shape_id(ishape)),
-                        self.sim.particle_flags[i] & (Flags.Infinite | Flags.Global))):
-                        for jshape in range(self.maxs):
-                            if self.include_interaction(ishape, jshape):
-                                # Globals are presenet in all ranks so they should not interact with ghosts
-                                # TODO: Make this loop the kernel candiate and reduce forces on the global body
-                                for j in ParticleFor(self.sim):
+                for jshape in range(self.maxs): # shape of locals
+                    if self.include_interaction(ishape, jshape):
+                        # Globals are present in all ranks so they should not interact with ghosts
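+                        # The loop over local particles is outermost so it becomes the kernel candidate; the global cell is scanned inside it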
+                        for j in ParticleFor(self.sim):
+                            # Loop over the global cell
+                            for p in For(self.sim, 0, self.cell_lists.cell_sizes[0]):
+                                i = self.cell_lists.cell_particles[0][p]
+                                # TODO: Skip if the bounding box of the global body doesn't intersect the subdom of this rank
+                                for _ in Filter(self.sim, ScalarOp.and_op(
+                                    ScalarOp.cmp(self.sim.particle_shape[i], self.sim.get_shape_id(ishape)),
+                                    self.sim.particle_flags[i] & (Flags.Infinite | Flags.Global))):
                                     # Here we make make sure not to interact with other global bodies, otherwise
                                     # their contributions will get reduced again over all ranks
                                     for _ in Filter(self.sim, ScalarOp.and_op(
diff --git a/src/pairs/sim/interaction.py b/src/pairs/sim/interaction.py
index 901c2d4b7f1930b9841b25e0f57e5e143fea77e0..2671a907e526e1b9aefe0677349c5d2233f962f7 100644
--- a/src/pairs/sim/interaction.py
+++ b/src/pairs/sim/interaction.py
@@ -423,6 +423,20 @@ class ParticleInteraction(Lowerable):
         self.sim.module_name(f"{self.module_name}_local_interactions")
         if self.nbody == 2:
             neighbor_lists = None if self.use_cell_lists else self.sim.neighbor_lists
+            # for ishape in range(self.maxs):
+            #     if self.include_shape(ishape):
+            #         for jshape in range(self.maxs):
+            #             if self.include_interaction(ishape, jshape):
+            #                 # A kernel for each ishape
+            #                 for i in ParticleFor(self.sim):
+            #                     for _ in Filter(self.sim, ScalarOp.and_op(
+            #                         ScalarOp.cmp(self.sim.particle_shape[i], self.sim.get_shape_id(ishape)),
+            #                         ScalarOp.not_op(self.sim.particle_flags[i] & (Flags.Infinite | Flags.Global)))):
+            #                                 # Inner loops for each jshaped neighbor
+            #                                 for neigh in NeighborFor(self.sim, i, self.cell_lists, neighbor_lists, jshape):
+            #                                     j = neigh.particle_index()
+            #                                     self.compute_interaction(i, j, ishape, jshape)
+
             for ishape in range(self.maxs):
                 if self.include_shape(ishape):
                     # A kernel for each ishape
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 0d5f2f8e6e2d72090373092c296e6abe53ee7a96..0e1164afb29e257fd5e6a44dd70cef3254e7eeaa 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -260,11 +260,12 @@ class Simulation:
     def reneighbor_every(self, frequency):
         self.reneighbor_frequency = frequency
 
-    def build_cell_lists(self, spacing=None, store_neighbors_per_cell=False, use_halo_cells=False):
+    def build_cell_lists(self, spacing=None, store_neighbors_per_cell=False, use_halo_cells=False, optimize_halo_paddings=False):
         """Add routines to build the linked-cells acceleration structure.
         Leave spacing as None so it can be set at runtime."""
         self._store_neighbors_per_cell = store_neighbors_per_cell
         self._use_halo_cells = use_halo_cells
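+        # Experimental: tighter per-AABB padding when generating halo cells (see BlockForest.halo_condition)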
+        self._optimize_halo_paddings = optimize_halo_paddings
         self.cell_lists = CellLists(self, self._dom_part, spacing, spacing)
         return self.cell_lists