From 842850ea407c181ff170079ffd2b0d00e90d68e1 Mon Sep 17 00:00:00 2001
From: Behzad Safaei <iwia103h@a0226.nhr.fau.de>
Date: Mon, 24 Feb 2025 11:44:58 +0100
Subject: [PATCH] Fix dynamic load balancing (initial working version)

---
 CMakeLists.txt                                |   9 +-
 examples/modular/sd_4.cpp                     |  99 ++++++++++
 .../spring_dashpot.py                         |  24 +--
 runtime/boundary_weights.cpp                  | 121 ++----------
 runtime/boundary_weights.cu                   | 108 +++++++++++
 runtime/boundary_weights.hpp                  |  25 +--
 runtime/contact_property.hpp                  |   2 +-
 runtime/dem_sc_grid.cpp                       |   8 +-
 runtime/devices/cuda.cu                       |   7 -
 runtime/devices/device.hpp                    |   6 +
 runtime/domain/ParticleDataHandling.hpp       | 172 +++++++++++++----
 runtime/domain/block_forest.cpp               | 182 ++++++++++--------
 runtime/domain/block_forest.hpp               |   7 +-
 runtime/domain/domain_partitioning.hpp        |   3 +
 runtime/domain/regular_6d_stencil.cpp         |   2 +
 runtime/domain/regular_6d_stencil.hpp         |   3 +
 runtime/pairs.cpp                             |   6 +-
 runtime/pairs.hpp                             |  10 +-
 runtime/pairs_common.hpp                      |  31 ++-
 runtime/property.hpp                          |   4 +-
 runtime/unique_id.hpp                         |   5 +
 runtime/vtk.cpp                               | 138 +++++++++++--
 runtime/vtk.hpp                               |   9 +-
 src/pairs/__init__.py                         |  13 ++
 src/pairs/code_gen/cgen.py                    |  17 +-
 src/pairs/code_gen/interface.py               |  52 +++--
 src/pairs/ir/types.py                         |   8 +-
 src/pairs/sim/comm.py                         |   1 +
 src/pairs/sim/domain.py                       |  20 --
 src/pairs/sim/domain_partitioning.py          |  28 ++-
 src/pairs/sim/load_balancing_algorithms.py    |  13 ++
 src/pairs/sim/simulation.py                   |  77 +++++---
 32 files changed, 826 insertions(+), 384 deletions(-)
 create mode 100644 examples/modular/sd_4.cpp
 create mode 100644 runtime/boundary_weights.cu
 create mode 100644 src/pairs/sim/load_balancing_algorithms.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ed0f398..27c9afc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -109,8 +109,15 @@ endif()
 if(USE_WALBERLA)
     set(RUNTIME_WALBERLA_FILES
         runtime/domain/block_forest.cpp
-        # runtime/boundary_weights.cpp  # avoid compiling this for now. TODO: generate the host and device functions
     )
+
+    # TODO: Generate the host/device functions for computing weights
+    if(COMPILE_CUDA)
+        list(APPEND RUNTIME_WALBERLA_FILES runtime/boundary_weights.cu)
+    else()
+        list(APPEND RUNTIME_WALBERLA_FILES runtime/boundary_weights.cpp)
+    endif()
+
     target_sources(${PAIRS_TARGET} PRIVATE ${RUNTIME_WALBERLA_FILES})
     target_compile_definitions(${PAIRS_TARGET} PUBLIC USE_WALBERLA)
 
diff --git a/examples/modular/sd_4.cpp b/examples/modular/sd_4.cpp
new file mode 100644
index 0000000..be064d9
--- /dev/null
+++ b/examples/modular/sd_4.cpp
@@ -0,0 +1,99 @@
+#include <iostream>
+#include <memory>
+#include <iomanip>
+
+#include "spring_dashpot.hpp"
+
+// Fills the four pair-wise property tables (stiffness, normal damping, friction, tangential damping)
+// with one value per (type, type) pair of the two particle types, calling the matching sync* after each.
+void set_feature_properties(std::shared_ptr<PairsAccessor> &ac){
+    constexpr int ntypes = 2;
+
+    for(int a = 0; a < ntypes; ++a) {
+        for(int b = 0; b < ntypes; ++b) { ac->setTypeStiffness(a, b, 100000); }
+    }
+    ac->syncTypeStiffness();
+
+    for(int a = 0; a < ntypes; ++a) {
+        for(int b = 0; b < ntypes; ++b) { ac->setTypeDampingNorm(a, b, 300); }
+    }
+    ac->syncTypeDampingNorm();
+
+    for(int a = 0; a < ntypes; ++a) {
+        for(int b = 0; b < ntypes; ++b) { ac->setTypeFriction(a, b, 0.5); }
+    }
+    ac->syncTypeFriction();
+
+    for(int a = 0; a < ntypes; ++a) {
+        for(int b = 0; b < ntypes; ++b) { ac->setTypeDampingTan(a, b, 20); }
+    }
+    ac->syncTypeDampingTan();
+}
+// Example driver: domain setup with dynamic load balancing, particle generation, then the time loop with VTK output
+int main(int argc, char **argv) {
+    auto pairs_sim = std::make_shared<PairsSimulation>();
+    pairs_sim->initialize();
+
+    auto ac = std::make_shared<PairsAccessor>(pairs_sim.get());
+    set_feature_properties(ac);
+
+    auto pairs_runtime = pairs_sim->getPairsRuntime();
+
+    pairs_runtime->initDomain(&argc, &argv, 
+                    0, 20, 0, 20, 0, 20,    // Domain bounds
+                    false, false, false,    // PBCs --------------> TODO: runtime pbc
+                    true                    // Enable dynamic load balancing (does initial refinement on a <1,1,1> blockforest)
+                ); 
+    // Hilbert-curve balancer; 100 / 1000 are presumably the regrid min / max (cf. the Python front-end) - TODO confirm
+    pairs_runtime->getDomainPartitioner()->initWorkloadBalancer(pairs::Hilbert, 100, 1000);
+    // Six half-spaces, three anchored at (0,0,0) and three at (20,20,20); presumably (point, normal, ...) arguments
+    pairs::create_halfspace(pairs_runtime, 0,0,0,  1, 0, 0,     0, 13);
+    pairs::create_halfspace(pairs_runtime, 0,0,0,  0, 1, 0,     0, 13);
+    pairs::create_halfspace(pairs_runtime, 0,0,0,  0, 0, 1,     0, 13);
+    pairs::create_halfspace(pairs_runtime, 20,20,20,  -1, 0, 0,    0, 13);
+    pairs::create_halfspace(pairs_runtime, 20,20,20,  0, -1, 0,    0, 13);
+    pairs::create_halfspace(pairs_runtime, 20,20,20,  0, 0, -1,    0, 13);
+
+    double diameter_min = 0.3;
+    double diameter_max = 0.3;
+    double sphere_spacing = 0.4;
+    pairs::dem_sc_grid(pairs_runtime, 10, 10, 15,  sphere_spacing, diameter_min, diameter_min, diameter_max,    2,      100,    2);
+    
+    double lcw = diameter_max * 1.01;       // Linked-cell width
+    double interaction_radius = diameter_max;
+    pairs_sim->setup_sim(lcw, lcw, lcw, interaction_radius);
+
+    pairs_sim->update_mass_and_inertia();
+
+    int num_timesteps = 4000;
+    int vtk_freq = 20;
+    int rebalance_freq = 200;
+    double dt = 1e-3;
+
+    pairs::vtk_write_subdom(pairs_runtime, "output/subdom_init", 0);
+
+    // Time loop
+    for (int t=0; t<num_timesteps; ++t){
+        if ((t % vtk_freq==0) && pairs_sim->rank()==0) std::cout << "Timestep: " << t << std::endl;
+        // Rebalance the domain every rebalance_freq steps (t = 0 included)
+        if (t % rebalance_freq == 0){ 
+            pairs_sim->update_domain();
+        }
+        
+        pairs_sim->update_cells(t); 
+        
+        pairs_sim->gravity(); 
+        pairs_sim->spring_dashpot(); 
+        pairs_sim->euler(dt); 
+        
+        pairs_sim->communicate(t);
+        // Every vtk_freq steps: write the subdomain file plus the particle ranges [0, nlocal) and [nlocal, size)
+        if (t % vtk_freq==0){
+            pairs::vtk_write_subdom(pairs_runtime, "output/subdom", t);
+            pairs::vtk_write_data(pairs_runtime, "output/sd_4_local", 0, pairs_sim->nlocal(), t);
+            pairs::vtk_write_data(pairs_runtime, "output/sd_4_ghost", pairs_sim->nlocal(), pairs_sim->size(), t);
+        }
+    }
+
+    pairs_sim->end();
+}
\ No newline at end of file
diff --git a/examples/whole-program-generation/spring_dashpot.py b/examples/whole-program-generation/spring_dashpot.py
index 397b542..c12ec34 100644
--- a/examples/whole-program-generation/spring_dashpot.py
+++ b/examples/whole-program-generation/spring_dashpot.py
@@ -51,21 +51,21 @@ def gravity(i):
 
 # Number of 'type' features and their pair-wise properties
 ntypes = 2
-stiffness_SI = [5000 for i in range(ntypes * ntypes)]
-dampingNorm_SI = [20 for i in range(ntypes * ntypes)]
-dampingTan_SI = [20 for i in range(ntypes * ntypes)]
-friction_SI = [0.2 for i in range(ntypes * ntypes)]
+stiffness_SI = [100000 for i in range(ntypes * ntypes)]
+dampingNorm_SI = [300 for i in range(ntypes * ntypes)]
+dampingTan_SI = [0.0 for i in range(ntypes * ntypes)]
+friction_SI = [0.0 for i in range(ntypes * ntypes)]
 
 # Domain size
-domainSize_SI=[1, 1, 1]
+domainSize_SI=[10, 10, 10]
 
 # Parameters required for generating the initial grid of particles 'dem_sc_grid'
-generationSpacing_SI = 0.1
-diameter_SI = 0.09
+generationSpacing_SI = 0.4
+diameter_SI = 0.3
 minDiameter_SI = diameter_SI
 maxDiameter_SI = diameter_SI
-initialVelocity_SI = 0.5
-densityParticle_SI = 1000
+initialVelocity_SI = 2
+densityParticle_SI = 100
 
 # Linked cell width 
 linkedCellWidth = 1.01 * maxDiameter_SI
@@ -77,9 +77,9 @@ gravity_SI = 9.81
 dt_SI = 1e-3
 
 # VTK frequency
-visSpacing = 60
+visSpacing = 20
 
-timeSteps = 6000
+timeSteps = 2000
 
 file_name = os.path.basename(__file__)
 file_name_without_extension = os.path.splitext(file_name)[0]
@@ -127,6 +127,8 @@ psim.build_cell_lists(linkedCellWidth)
 
 psim.set_domain([0.0, 0.0, 0.0, domainSize_SI[0], domainSize_SI[1], domainSize_SI[2]])
 
+psim.set_workload_balancer(pairs.morton(), regrid_min=100, regrid_max=1000, rebalance_frequency=200)
+
 # Generate particles
 psim.dem_sc_grid(domainSize_SI[0], domainSize_SI[1], domainSize_SI[2], 
                  generationSpacing_SI,
diff --git a/runtime/boundary_weights.cpp b/runtime/boundary_weights.cpp
index 2174da8..3a67d29 100644
--- a/runtime/boundary_weights.cpp
+++ b/runtime/boundary_weights.cpp
@@ -1,86 +1,7 @@
-#include <iostream>
-#include <string.h>
-#include <fstream>
-#include <sstream>
-//---
 #include "boundary_weights.hpp"
-#include "pairs.hpp"
-#include "pairs_common.hpp"
 
 // Always include last generated interfaces
-#include "interfaces/last_generated.hpp"
-
-#ifdef PAIRS_TARGET_CUDA
-
-#define REDUCE_BLOCK_SIZE   64
-
-__global__ void reduceBoundaryWeights(
-    real_t *position, int start, int end, int particle_capacity,
-    real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax, int *d_weights) {
-
-    __shared__ int red_data[REDUCE_BLOCK_SIZE];
-    int tid = threadIdx.x;
-    int i = blockIdx.x * blockDim.x + tid;
-    int particle_idx = start + i;
-
-    red_data[tid] = 0;
-
-    if(particle_idx < end) {
-        real_t pos_x = pairs_cuda_interface::get_position(position, particle_idx, 0, particle_capacity);
-        real_t pos_y = pairs_cuda_interface::get_position(position, particle_idx, 1, particle_capacity);
-        real_t pos_z = pairs_cuda_interface::get_position(position, particle_idx, 2, particle_capacity);
-
-        if( pos_x > xmin && pos_x <= xmax &&
-            pos_y > ymin && pos_y <= ymax &&
-            pos_z > zmin && pos_z <= zmax) {
-                red_data[tid] = 1;
-        }
-    }
-
-    __syncthreads();
-
-    int s = blockDim.x >> 1;
-    while(s > 0) {
-        if(tid < s) {
-            red_data[tid] += red_data[tid + s];
-        }
-
-        __syncthreads();
-        s >>= 1;
-    }
-
-    if(tid == 0) {
-        d_weights[blockIdx.x] = red_data[0];
-    }
-}
-
-int cuda_compute_boundary_weights(
-    real_t *position, int start, int end, int particle_capacity,
-    real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax) {
-
-    const int nblocks = (end - start + (REDUCE_BLOCK_SIZE - 1)) / REDUCE_BLOCK_SIZE;
-    int *h_weights = (int *) malloc(nblocks * sizeof(int));
-    int *d_weights = (int *) device_alloc(nblocks * sizeof(int));
-    int red = 0;
-
-    CUDA_ASSERT(cudaMemset(d_weights, 0, nblocks * sizeof(int)));
-
-    reduceBoundaryWeights<<<nblocks, REDUCE_BLOCK_SIZE>>>(
-            position, start, end, particle_capacity,
-            xmin, xmax, ymin, ymax, zmin, zmax, d_weights);
-
-    CUDA_ASSERT(cudaPeekAtLastError());
-    CUDA_ASSERT(cudaDeviceSynchronize());
-    CUDA_ASSERT(cudaMemcpy(h_weights, d_weights, nblocks * sizeof(int), cudaMemcpyDeviceToHost));
-
-    for(int i = 0; i < nblocks; i++) {
-        red += h_weights[i];
-    }
-
-    return red;
-}
-#endif
-
+#include "last_generated.hpp"
 namespace pairs {
 
 void compute_boundary_weights(
@@ -90,50 +11,36 @@ void compute_boundary_weights(
 
     const int particle_capacity = ps->getTrackedVariableAsInteger("particle_capacity");
     const int nlocal = ps->getTrackedVariableAsInteger("nlocal");
-    const int nghost = ps->getTrackedVariableAsInteger("nghost");
     auto position_prop = ps->getPropertyByName("position");
+    auto flags_prop = ps->getPropertyByName("flags");
 
-    #ifndef PAIRS_TARGET_CUDA
     real_t *position_ptr = static_cast<real_t *>(position_prop.getHostPointer());
+    int *flags_ptr = static_cast<int *>(flags_prop.getHostPointer());
 
     *comp_weight = 0;
-    *comm_weight = 0;
 
     for(int i = 0; i < nlocal; i++) {
-        real_t pos_x = pairs_host_interface::get_position(position_ptr, i, 0, particle_capacity);
-        real_t pos_y = pairs_host_interface::get_position(position_ptr, i, 1, particle_capacity);
-        real_t pos_z = pairs_host_interface::get_position(position_ptr, i, 2, particle_capacity);
-
-        if( pos_x > xmin && pos_x <= xmax &&
-            pos_y > ymin && pos_y <= ymax &&
-            pos_z > zmin && pos_z <= zmax) {
-                (*comp_weight)++;
+        if (pairs_host_interface::get_flags(flags_ptr, i) & (pairs::flags::INFINITE | pairs::flags::GLOBAL)) {
+            continue;
         }
-    }
 
-    for(int i = nlocal; i < nlocal + nghost; i++) {
         real_t pos_x = pairs_host_interface::get_position(position_ptr, i, 0, particle_capacity);
         real_t pos_y = pairs_host_interface::get_position(position_ptr, i, 1, particle_capacity);
         real_t pos_z = pairs_host_interface::get_position(position_ptr, i, 2, particle_capacity);
 
-        if( pos_x > xmin && pos_x <= xmax &&
-            pos_y > ymin && pos_y <= ymax &&
-            pos_z > zmin && pos_z <= zmax) {
-                (*comm_weight)++;
+        if( pos_x >= xmin && pos_x < xmax &&
+            pos_y >= ymin && pos_y < ymax &&
+            pos_z >= zmin && pos_z < zmax) {
+                (*comp_weight)++;
         }
     }
-    // std::cout << "comp_weight = " << (*comp_weight) << ", comm_weight = " << (*comm_weight) << std::endl;
-    #else
-    real_t *position_ptr = static_cast<real_t *>(position_prop.getDevicePointer());
 
-    ps->copyPropertyToDevice(position_prop.getId(), ReadOnly);
-
-    *comp_weight = cuda_compute_boundary_weights(
-        position_ptr, 0, nlocal, particle_capacity, xmin, xmax, ymin, ymax, zmin, zmax);
+    // TODO: Count the number of ghosts that must be communicated with this block.
+    // Note: The ghosts stored in this rank are NOT contained in the aabb of any of its blocks.
+    //       And neighbor blocks are going to change after rebalancing.
+    // const int nghost = ps->getTrackedVariableAsInteger("nghost");
+    *comm_weight = 0;
 
-    *comm_weight = cuda_compute_boundary_weights(
-        position_ptr, nlocal, nlocal + nghost, particle_capacity, xmin, xmax, ymin, ymax, zmin, zmax);
-    #endif
 }
 
 }
diff --git a/runtime/boundary_weights.cu b/runtime/boundary_weights.cu
new file mode 100644
index 0000000..191139f
--- /dev/null
+++ b/runtime/boundary_weights.cu
@@ -0,0 +1,108 @@
+#include "boundary_weights.hpp"
+// #include "devices/device.hpp"
+
+// Always include last generated interfaces
+#include "last_generated.hpp"
+#define CUDA_ASSERT(a) { pairs::cuda_assert((a), __FILE__, __LINE__); }
+
+namespace pairs {
+
+#define REDUCE_BLOCK_SIZE 64
+
+// Counts the particles with index in [start, end) whose position lies inside the half-open box
+// [xmin,xmax) x [ymin,ymax) x [zmin,zmax); particles flagged INFINITE or GLOBAL are skipped.
+// Every thread block stores its partial count in d_weights[blockIdx.x]. The shared buffer and the
+// halving reduction assume blockDim.x is a power of two no larger than REDUCE_BLOCK_SIZE.
+__global__ void reduceBoundaryWeights( real_t *position, int *flags, int start, int end, int particle_capacity,
+    real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax, int *d_weights) {
+
+    __shared__ int red_data[REDUCE_BLOCK_SIZE];
+    int tid = threadIdx.x;
+    int i = blockIdx.x * blockDim.x + tid;
+    int particle_idx = start + i;
+    red_data[tid] = 0;
+
+    if(particle_idx < end) {
+        // Flags are looked up by particle index (not by thread offset) so that start != 0 reads the right particle
+        if (!(pairs_cuda_interface::get_flags(flags, particle_idx) & (pairs::flags::INFINITE | pairs::flags::GLOBAL))) {
+            real_t pos_x = pairs_cuda_interface::get_position(position, particle_idx, 0, particle_capacity);
+            real_t pos_y = pairs_cuda_interface::get_position(position, particle_idx, 1, particle_capacity);
+            real_t pos_z = pairs_cuda_interface::get_position(position, particle_idx, 2, particle_capacity);
+
+            if( pos_x >= xmin && pos_x < xmax &&
+                pos_y >= ymin && pos_y < ymax &&
+                pos_z >= zmin && pos_z < zmax) {
+                    red_data[tid] = 1;
+            }
+        }
+    }
+
+    __syncthreads();
+    // Pairwise reduction over the shared array; red_data[0] ends up holding the block total
+    for(int s = blockDim.x >> 1; s > 0; s >>= 1) {
+        if(tid < s) {
+            red_data[tid] += red_data[tid + s];
+        }
+        __syncthreads();
+    }
+
+    if(tid == 0) {
+        d_weights[blockIdx.x] = red_data[0];
+    }
+}
+
+// Host driver: launches reduceBoundaryWeights over the particles [start, end) and sums the per-block
+// partial counts on the host. Returns 0 without launching anything when the range is empty.
+int cuda_compute_boundary_weights(
+    real_t *position, int *flags, int start, int end, int particle_capacity,
+    real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax) {
+
+    if (start >= end) return 0;
+    const int nblocks = (end - start + (REDUCE_BLOCK_SIZE - 1)) / REDUCE_BLOCK_SIZE;
+    int *h_weights = (int *) malloc(nblocks * sizeof(int));
+    int *d_weights = (int *) device_alloc(nblocks * sizeof(int));
+    int red = 0;
+
+    CUDA_ASSERT(cudaMemset(d_weights, 0, nblocks * sizeof(int)));
+    reduceBoundaryWeights<<<nblocks, REDUCE_BLOCK_SIZE>>>(
+            position, flags, start, end, particle_capacity,
+            xmin, xmax, ymin, ymax, zmin, zmax, d_weights);
+    CUDA_ASSERT(cudaPeekAtLastError());
+    CUDA_ASSERT(cudaDeviceSynchronize());
+    CUDA_ASSERT(cudaMemcpy(h_weights, d_weights, nblocks * sizeof(int), cudaMemcpyDeviceToHost));
+
+    for(int i = 0; i < nblocks; i++) { red += h_weights[i]; }
+    // Release both scratch buffers (device_alloc wraps cudaMalloc); they used to leak on every call
+    free(h_weights);
+    CUDA_ASSERT(cudaFree(d_weights));
+    return red;
+}
+
+// CUDA variant: computes the computation / communication weights for the box given by the bounds.
+// comp_weight receives the number of local particles inside the box, counted on the device.
+// comm_weight is not implemented yet and is always set to 0.
+void compute_boundary_weights(
+    PairsRuntime *ps,
+    real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax,
+    long unsigned int *comp_weight, long unsigned int *comm_weight) {
+
+    const int particle_capacity = ps->getTrackedVariableAsInteger("particle_capacity");
+    const int nlocal = ps->getTrackedVariableAsInteger("nlocal");
+    auto position_prop = ps->getPropertyByName("position");
+    auto flags_prop = ps->getPropertyByName("flags");
+    real_t *position_ptr = static_cast<real_t *>(position_prop.getDevicePointer());
+    int *flags_ptr = static_cast<int *>(flags_prop.getDevicePointer());
+
+    // The counting kernel reads positions and flags through the device pointers
+    ps->copyPropertyToDevice(position_prop.getId(), ReadOnly);
+    ps->copyPropertyToDevice(flags_prop.getId(), ReadOnly);
+
+    *comp_weight = cuda_compute_boundary_weights(
+        position_ptr, flags_ptr, 0, nlocal, particle_capacity, xmin, xmax, ymin, ymax, zmin, zmax);
+
+    // TODO: Count the ghosts that must be communicated with this block, i.e. the range
+    //       [nlocal, nlocal + nghost) with nghost = ps->getTrackedVariableAsInteger("nghost").
+    *comm_weight = 0;
+}
+
+}
diff --git a/runtime/boundary_weights.hpp b/runtime/boundary_weights.hpp
index bb310ad..e84348a 100644
--- a/runtime/boundary_weights.hpp
+++ b/runtime/boundary_weights.hpp
@@ -1,25 +1,20 @@
-#include "pairs.hpp"
+#pragma once
 
-/*
-#define INTERFACE_DIR "interfaces/"
-#define INTERFACE_EXT ".hpp"
-#define INTERFACE_FILE(a, b, c) a ## b ## c
-#define INCLUDE_FILE(filename) #filename
-#include INCLUDE_FILE(INTERFACE_FILE(INTERFACE_DIR, APPLICATION_REFERENCE, INTERFACE_EXT))
-*/
+#include "pairs.hpp"
+#include <iostream>
+#include <string.h>
+#include <fstream>
+#include <sstream>
+//---
+#include "pairs.hpp"
+#include "pairs_common.hpp"
 
-#pragma once
 
 namespace pairs {
 
 void compute_boundary_weights(
     PairsRuntime *ps,
     real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax,
-    long unsigned int *comp_weight, long unsigned int *comm_weight)
-    {
-        std::cerr<< "TODO: boundary weights should be generated" << std::endl;
-        exit(-1);
-    }
-    ;
+    long unsigned int *comp_weight, long unsigned int *comm_weight);
 
 }
diff --git a/runtime/contact_property.hpp b/runtime/contact_property.hpp
index be31b51..a64992f 100644
--- a/runtime/contact_property.hpp
+++ b/runtime/contact_property.hpp
@@ -37,7 +37,7 @@ public:
     layout_t getLayout() const { return layout; }
     size_t getPrimitiveTypeSize() const {
         return  (type == Prop_Integer) ? sizeof(int) :
-                (type == Prop_UInt64) ? sizeof(unsigned long long int) :
+                (type == Prop_UInt64) ? sizeof(uint64_t) :
                 (type == Prop_Real) ? sizeof(real_t) :
                 (type == Prop_Vector) ? sizeof(real_t) :
                 (type == Prop_Matrix) ? sizeof(real_t) :
diff --git a/runtime/dem_sc_grid.cpp b/runtime/dem_sc_grid.cpp
index 863f826..119ec78 100644
--- a/runtime/dem_sc_grid.cpp
+++ b/runtime/dem_sc_grid.cpp
@@ -34,7 +34,7 @@ int dem_sc_grid(PairsRuntime *ps, double xmax, double ymax, double zmax, double
 
     const double xmin = 0.0;
     const double ymin = 0.0;
-    const double zmin = diameter;
+    const double zmin = 0.0;
 
     double gen_domain[] = {xmin, ymin, zmin, xmax, ymax, zmax};
     double ref_point[] = {spacing * 0.5, spacing * 0.5, spacing * 0.5};
@@ -68,10 +68,10 @@ int dem_sc_grid(PairsRuntime *ps, double xmax, double ymax, double zmax, double
             positions(nparticles, 2) = point[2];
             velocities(nparticles, 0) = 0.1 * realRandom<real_t>(-initial_velocity, initial_velocity);
             velocities(nparticles, 1) = 0.1 * realRandom<real_t>(-initial_velocity, initial_velocity);
-            velocities(nparticles, 2) = -initial_velocity;
+            velocities(nparticles, 2) = 0.1 * realRandom<real_t>(-initial_velocity, initial_velocity);
             types(nparticles) = rand() % ntypes;
             flags(nparticles) = 0;
-            shapes(nparticles) = 0; // sphere
+            shapes(nparticles) = shapes::Sphere;
 
             /*
             std::cout << uid(nparticles) << "," << types(nparticles) << "," << masses(nparticles) << "," << radius(nparticles) << ","
@@ -109,6 +109,8 @@ int dem_sc_grid(PairsRuntime *ps, double xmax, double ymax, double zmax, double
         }
     }
 
+    ps->setTrackedVariableAsInteger("nlocal", nparticles);
+
     int global_nparticles = nparticles;
     if(ps->getDomainPartitioner()->getWorldSize() > 1) {
         MPI_Allreduce(&nparticles, &global_nparticles, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
diff --git a/runtime/devices/cuda.cu b/runtime/devices/cuda.cu
index 4caad6d..2cae5aa 100644
--- a/runtime/devices/cuda.cu
+++ b/runtime/devices/cuda.cu
@@ -8,13 +8,6 @@
 
 namespace pairs {
 
-inline void cuda_assert(cudaError_t err, const char *file, int line) {
-    if(err != cudaSuccess) {
-        std::cerr << file << ":" << line << ": " << cudaGetErrorString(err) << std::endl;
-        exit(-1);
-    }
-}
-
 __host__ void *device_alloc(size_t size) {
     void *ptr;
     CUDA_ASSERT(cudaMalloc(&ptr, size));
diff --git a/runtime/devices/device.hpp b/runtime/devices/device.hpp
index 60656d4..c5c406e 100644
--- a/runtime/devices/device.hpp
+++ b/runtime/devices/device.hpp
@@ -73,6 +73,12 @@ inline __host__ int host_atomic_add_resize_check(int *addr, int val, int *resize
 }
 
 #ifdef PAIRS_TARGET_CUDA
+inline void cuda_assert(cudaError_t err, const char *file, int line) {
+    if(err == cudaSuccess) { return; }
+    // Print "<file>:<line>: <CUDA error string>" and terminate the process
+    std::cerr << file << ":" << line << ": " << cudaGetErrorString(err) << std::endl;
+    exit(-1);
+}
 __device__ double atomicAdd_double(double* address, double val);
 __device__ int atomic_add(int *addr, int val);
 __device__ real_t atomic_add(real_t *addr, real_t val);
diff --git a/runtime/domain/ParticleDataHandling.hpp b/runtime/domain/ParticleDataHandling.hpp
index f66bb70..cf467d9 100644
--- a/runtime/domain/ParticleDataHandling.hpp
+++ b/runtime/domain/ParticleDataHandling.hpp
@@ -7,6 +7,49 @@ namespace pairs {
 
 class PairsRuntime;
 
+// Copies every non-volatile property of particle `src` into particle slot `dst`.
+// The function is defined in a header, so it is declared `inline` to avoid multiple-definition
+// link errors when the header is included from more than one translation unit.
+// On a property of unknown type an error is printed and the remaining properties are not copied.
+inline void relocate_particle(PairsRuntime *ps, int dst, int src){
+    for(auto &prop: ps->getProperties()) {
+        if(prop.isVolatile()) { continue; }
+        const auto prop_type = prop.getType();
+
+        if(prop_type == pairs::Prop_Vector) {
+            auto vector_ptr = ps->getAsVectorProperty(prop);
+            constexpr int nelems = 3;
+            for(int e = 0; e < nelems; e++) {
+                vector_ptr(dst, e) = vector_ptr(src, e);
+            }
+        } else if(prop_type == pairs::Prop_Matrix) {
+            auto matrix_ptr = ps->getAsMatrixProperty(prop);
+            constexpr int nelems = 9;
+            for(int e = 0; e < nelems; e++) {
+                matrix_ptr(dst, e) = matrix_ptr(src, e);
+            }
+        } else if(prop_type == pairs::Prop_Quaternion) {
+            auto quat_ptr = ps->getAsQuaternionProperty(prop);
+            constexpr int nelems = 4;
+            for(int e = 0; e < nelems; e++) {
+                quat_ptr(dst, e) = quat_ptr(src, e);
+            }
+        } else if(prop_type == pairs::Prop_Integer) {
+            auto int_ptr = ps->getAsIntegerProperty(prop);
+            int_ptr(dst) = int_ptr(src);
+        } else if(prop_type == pairs::Prop_UInt64) {
+            auto uint64_ptr = ps->getAsUInt64Property(prop);
+            uint64_ptr(dst) = uint64_ptr(src);
+        } else if(prop_type == pairs::Prop_Real) {
+            auto float_ptr = ps->getAsFloatProperty(prop);
+            float_ptr(dst) = float_ptr(src);
+        } else {
+            std::cerr << "relocate_particle(): Invalid property type!" << std::endl;
+            return;
+        }
+    }
+}
+
 }
 
 namespace walberla {
@@ -17,15 +60,56 @@ class ParticleDeleter {
     friend bool operator==(const ParticleDeleter& lhs, const ParticleDeleter& rhs);
 
 public:
-    ParticleDeleter(const math::AABB& aabb) : aabb_(aabb) {}
-    ~ParticleDeleter() {}
+    ParticleDeleter(pairs::PairsRuntime *ps_, const math::AABB& aabb_) : ps(ps_), aabb(aabb_) {}
+
+    // On destruction: drops every local particle located inside the stored AABB (INFINITE / GLOBAL
+    // particles are kept), compacts the remaining locals and resets the ghost count to 0.
+    ~ParticleDeleter() {
+        const int nlocal = ps->getTrackedVariableAsInteger("nlocal");
+        auto position = ps->getAsVectorProperty(ps->getPropertyByName("position"));
+        auto flags = ps->getAsIntegerProperty(ps->getPropertyByName("flags"));
+
+        int ndeleted = 0;
+        int *goneIdx = new int[nlocal];
+        for (int i=0; i<nlocal; ++i) {
+            if (flags(i) & (pairs::flags::INFINITE | pairs::flags::GLOBAL))  continue;
+            const real_t pos_x = position(i, 0);
+            const real_t pos_y = position(i, 1);
+            const real_t pos_z = position(i, 2);
+            if( aabb.contains(pos_x, pos_y, pos_z)) {
+                goneIdx[ndeleted] = i;
+                ++ndeleted;
+            }
+        }
+
+        // Fill the holes (ascending in goneIdx) with surviving particles taken from the back of the local range.
+        // `beg <= end` is tested first so that goneIdx[beg] is only read while it refers to a recorded hole.
+        int beg = 0;
+        int end = ndeleted - 1;
+        int i = nlocal - 1;
+        while ((beg <= end) && (i > goneIdx[beg])) {
+            if(i == goneIdx[end]){
+                --end;
+            }
+            else{
+                pairs::relocate_particle(ps, goneIdx[beg], i);
+                ++beg;
+            }
+            --i;
+        }
+
+        delete[] goneIdx;
+        ps->setTrackedVariableAsInteger("nlocal", nlocal - ndeleted);
+        ps->setTrackedVariableAsInteger("nghost", 0);
+    }
 
 private:
-    math::AABB aabb_;
+    pairs::PairsRuntime *ps;
+    math::AABB aabb;
 };
 
 inline bool operator==(const ParticleDeleter& lhs, const ParticleDeleter& rhs) {
-    return lhs.aabb_ == rhs.aabb_;
+    return lhs.aabb == rhs.aabb;
 }
 
 } // namespace internal
@@ -39,7 +123,7 @@ public:
     ~ParticleDataHandling() override = default;
 
     internal::ParticleDeleter *initialize(IBlock *const block) override {
-        return new internal::ParticleDeleter(block->getAABB());
+        return new internal::ParticleDeleter(ps, block->getAABB());
     }
 
     void serialize(IBlock *const block, const BlockDataID& id, mpi::SendBuffer& buffer) override {
@@ -101,28 +185,30 @@ public:
             aabb_check[5] = aabb.zMax();
         }
 
-        for(auto& prop: ps->getProperties()) {
-            if(!prop.isVolatile()) {
-                ps->copyPropertyToHost(prop, pairs::WriteAfterRead);
-            }
-        }
-
-        auto position = ps->getAsVectorProperty(ps->getPropertyByName("position"));
         int nlocal = ps->getTrackedVariableAsInteger("nlocal");
-        int i = 0;
+        auto position = ps->getAsVectorProperty(ps->getPropertyByName("position"));
+        auto flags = ps->getAsIntegerProperty(ps->getPropertyByName("flags"));
         int nserialized = 0;
+        int *goneIdx = new int[nlocal];
 
-        while(i < nlocal) {
+        for (int i=0; i<nlocal; ++i) {
+            if (flags(i) & (pairs::flags::INFINITE | pairs::flags::GLOBAL)) continue;
             const real_t pos_x = position(i, 0);
             const real_t pos_y = position(i, 1);
             const real_t pos_z = position(i, 2);
 
-            if( pos_x > aabb_check[0] && pos_x <= aabb_check[1] &&
-                pos_y > aabb_check[2] && pos_y <= aabb_check[3] &&
-                pos_z > aabb_check[4] && pos_z <= aabb_check[5]) {
-
-                nlocal--;
-
+            // Important: When rebalancing, it is assumed that all particles are within domain bounds.  
+            // If a particle's center of mass lies outside the domain, it won't be contained
+            // in any of the checked blocks during serialization. In that case, the particle  
+            // can become disassociated from its owner if the new block it should belong to is  
+            // not an immediate neighbor to its owner rank. (if it's in an immediate neighbor, it will be exchanged)
+            if( pos_x >= aabb_check[0] && pos_x < aabb_check[1] &&
+                pos_y >= aabb_check[2] && pos_y < aabb_check[3] &&
+                pos_z >= aabb_check[4] && pos_z < aabb_check[5]) {
+
+                goneIdx[nserialized] = i;
+                ++nserialized;
+                
                 for(auto &prop: ps->getProperties()) {
                     if(!prop.isVolatile()) {
                         auto prop_type = prop.getType();
@@ -133,7 +219,6 @@ public:
 
                             for(int e = 0; e < nelems; e++) {
                                 buffer << vector_ptr(i, e);
-                                vector_ptr(i, e) = vector_ptr(nlocal, e);
                             }
                         } else if(prop_type == pairs::Prop_Matrix) {
                             auto matrix_ptr = ps->getAsMatrixProperty(prop);
@@ -141,7 +226,6 @@ public:
 
                             for(int e = 0; e < nelems; e++) {
                                 buffer << matrix_ptr(i, e);
-                                matrix_ptr(i, e) = matrix_ptr(nlocal, e);
                             }
                         } else if(prop_type == pairs::Prop_Quaternion) {
                             auto quat_ptr = ps->getAsQuaternionProperty(prop);
@@ -149,33 +233,48 @@ public:
 
                             for(int e = 0; e < nelems; e++) {
                                 buffer << quat_ptr(i, e);
-                                quat_ptr(i, e) = quat_ptr(nlocal, e);
                             }
                         } else if(prop_type == pairs::Prop_Integer) {
                             auto int_ptr = ps->getAsIntegerProperty(prop);
-                            buffer << int_ptr(i);
-                            int_ptr(i) = int_ptr(nlocal);
+                                buffer << int_ptr(i);
                         } else if(prop_type == pairs::Prop_UInt64) {
                             auto uint64_ptr = ps->getAsUInt64Property(prop);
-                            buffer << uint64_ptr(i);
-                            uint64_ptr(i) = uint64_ptr(nlocal);
+                                buffer << uint64_ptr(i);
                         } else if(prop_type == pairs::Prop_Real) {
                             auto float_ptr = ps->getAsFloatProperty(prop);
-                            buffer << float_ptr(i);
-                            float_ptr(i) = float_ptr(nlocal);
+                                buffer << float_ptr(i);
                         } else {
                             std::cerr << "serializeImpl(): Invalid property type!" << std::endl;
                             return;
                         }
                     }
                 }
+                // TODO: serialize contact history data as well
             }
+        }
 
-            // TODO: serialize contact history data as well
-            nserialized++;
+        // Here we replace serialized particles with the remaining locals 
+        // (Traverse locals in reverse order and move them to empty slots)
+        // Ghosts are ignored since they become invalid after rebalancing
+        int beg = 0;
+        int end = nserialized - 1;
+        int i = nlocal - 1;
+        while ((i > goneIdx[beg]) && (beg <= end)) {
+            if(i == goneIdx[end]){
+                --end;
+            }
+            else{
+                pairs::relocate_particle(ps, goneIdx[beg], i);
+                ++beg;
+            }
+            --i;
         }
 
-        ps->setTrackedVariableAsInteger("nlocal", nlocal);
+        delete[] goneIdx;
+
+        ps->setTrackedVariableAsInteger("nlocal", nlocal - nserialized);
+        ps->setTrackedVariableAsInteger("nghost", 0);
+        
         *ptr = (uint_t) nserialized;
     }
 
@@ -185,13 +284,13 @@ public:
         real_t real_tmp;
         int int_tmp;
         uint_t nrecv;
-        unsigned long long int uint64_tmp;
+        uint64_t uint64_tmp;
 
         buffer >> nrecv;
-
+        
         // TODO: Check if there is enough particle capacity for the new particles, when there is not,
         // all properties and arrays which have particle_capacity as one of their dimensions must be reallocated
-        // PAIRS_ASSERT(nlocal + nrecv < particle_capacity);
+        PAIRS_ASSERT(nlocal + nrecv < particle_capacity);
 
         for(int i = 0; i < nrecv; ++i) {
             for(auto &prop: ps->getProperties()) {
@@ -241,8 +340,9 @@ public:
                 }
             }
         }
-
+        
         ps->setTrackedVariableAsInteger("nlocal", nlocal + nrecv);
+        ps->setTrackedVariableAsInteger("nghost", 0);
     }
 };
 
diff --git a/runtime/domain/block_forest.cpp b/runtime/domain/block_forest.cpp
index fcc5329..0851f2c 100644
--- a/runtime/domain/block_forest.cpp
+++ b/runtime/domain/block_forest.cpp
@@ -18,16 +18,17 @@
 #include "../devices/device.hpp"
 #include "regular_6d_stencil.hpp"
 #include "ParticleDataHandling.hpp"
+#include "../unique_id.hpp"
 
 namespace pairs {
 
 BlockForest::BlockForest(
         PairsRuntime *ps_,
-        real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax, bool pbcx, bool pbcy, bool pbcz) :
-        DomainPartitioner(xmin, xmax, ymin, ymax, zmin, zmax), ps(ps_), globalPBC{pbcx, pbcy, pbcz} {
+        real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax, bool pbcx, bool pbcy, bool pbcz, bool balance_workload_) :  // domain bounds, per-axis periodicity, load-balancing switch
+        DomainPartitioner(xmin, xmax, ymin, ymax, zmin, zmax), ps(ps_), globalPBC{pbcx, pbcy, pbcz}, balance_workload(balance_workload_) {  // balance_workload_ is stored and later gates the rebalancing branch of update()
 
         subdom = new real_t[ndims * 2];
-    }
+}  // subdom now has room for ndims * 2 values; the forest itself is not created here
 
 BlockForest::BlockForest(PairsRuntime *ps_, const std::shared_ptr<walberla::blockforest::BlockForest> &bf) :
         forest(bf),
@@ -35,22 +36,13 @@ BlockForest::BlockForest(PairsRuntime *ps_, const std::shared_ptr<walberla::bloc
                         bf->getDomain().yMin(), bf->getDomain().yMax(),
                         bf->getDomain().zMin(), bf->getDomain().zMax()), 
         ps(ps_), 
-        globalPBC{bf->isXPeriodic(), bf->isYPeriodic(), bf->isZPeriodic()} 
-        {
+        globalPBC{bf->isXPeriodic(), bf->isYPeriodic(), bf->isZPeriodic()} {
             subdom = new real_t[ndims * 2];
-            balance_workload = 0;
-
             mpiManager = walberla::mpi::MPIManager::instance();
             world_size = mpiManager->numProcesses();
             rank = mpiManager->rank();
             this->info = make_shared<walberla::blockforest::InfoCollection>();
-
-            if(balance_workload) {
-                this->initializeWorkloadBalancer();
-            }
-
-        }
-
+}
 
 void BlockForest::updateNeighborhood() {
     std::map<int, std::vector<walberla::math::AABB>> neighborhood;
@@ -63,33 +55,26 @@ void BlockForest::updateNeighborhood() {
     naabbs.clear();
     aabb_offsets.clear();
     aabbs.clear();
-
     for(auto& iblock: *forest) {
         auto block = static_cast<walberla::blockforest::Block *>(&iblock);
-        auto& block_info = (*info)[block->getId()];
-
-        // don't check computationalWeight for now (TODO: compute_boundary_weights)
-        // if(block_info.computationalWeight > 0) {
-            for(uint neigh = 0; neigh < block->getNeighborhoodSize(); ++neigh) {
-                auto neighbor_rank = walberla::int_c(block->getNeighborProcess(neigh));
-
-                // if(neighbor_rank != me) {
-                    const walberla::BlockID& neighbor_block = block->getNeighborId(neigh);
-                    walberla::math::AABB neighbor_aabb = block->getNeighborAABB(neigh);
-                    auto neighbor_info = (*info)[neighbor_block];
-                    auto begin = blocks_pushed[neighbor_rank].begin();
-                    auto end = blocks_pushed[neighbor_rank].end();
-
-                    // if(neighbor_info.computationalWeight > 0 &&
-                    if(   find_if(begin, end, [neighbor_block](const auto &nbh) {
-                            return nbh == neighbor_block; }) == end) {
-
-                        neighborhood[neighbor_rank].push_back(neighbor_aabb);
-                        blocks_pushed[neighbor_rank].push_back(neighbor_block);
-                    }
-                // }
+        for(uint neigh = 0; neigh < block->getNeighborhoodSize(); ++neigh) {
+            auto neighbor_rank = walberla::int_c(block->getNeighborProcess(neigh));
+
+            // Neighbor blocks that belong to the same rank should be added to the
+            // neighborhood only if there's PBC along any dim; otherwise they should be skipped.
+            // TODO: Make PBCs work with runtime load balancing
+            if((neighbor_rank != me) || globalPBC[0] || globalPBC[1] || globalPBC[2]) {
+                const walberla::BlockID& neighbor_id = block->getNeighborId(neigh);
+                walberla::math::AABB neighbor_aabb = block->getNeighborAABB(neigh);
+                auto begin = blocks_pushed[neighbor_rank].begin();
+                auto end = blocks_pushed[neighbor_rank].end();
+                
+                if(find_if(begin, end, [neighbor_id](const auto &bp) { return bp == neighbor_id; }) == end) {
+                    neighborhood[neighbor_rank].push_back(neighbor_aabb);
+                    blocks_pushed[neighbor_rank].push_back(neighbor_id);
+                }
             }
-        // }
+        }
     }
 
     for(auto& nbh: neighborhood) {
@@ -130,28 +115,41 @@ void BlockForest::updateWeights() {
     walberla::mpi::BufferSystem bs(mpiManager->comm(), 756);
 
     info->clear();
+
+    int sum_block_locals = 0;
+    // Compute the weights for my blocks and their children
     for(auto& iblock: *forest) {
         auto block = static_cast<walberla::blockforest::Block *>(&iblock);
         auto aabb = block->getAABB();
         auto& block_info = (*info)[block->getId()];
-        // TODO: Generate boundary weights
-        // pairs::compute_boundary_weights(
-        //     this->ps,
-        //     aabb.xMin(), aabb.xMax(), aabb.yMin(), aabb.yMax(), aabb.zMin(), aabb.zMax(),
-        //     &(block_info.computationalWeight), &(block_info.communicationWeight));
+
+        pairs::compute_boundary_weights(
+            this->ps,
+            aabb.xMin(), aabb.xMax(), aabb.yMin(), aabb.yMax(), aabb.zMin(), aabb.zMax(),
+            &(block_info.computationalWeight), &(block_info.communicationWeight));
+        
+        sum_block_locals += block_info.computationalWeight;
 
         for(int branch = 0; branch < 8; ++branch) {
             const auto b_id = walberla::BlockID(block->getId(), branch);
             const auto b_aabb = forest->getAABBFromBlockId(b_id);
             auto& b_info = (*info)[b_id];
 
-            // pairs::compute_boundary_weights(
-            //     this->ps,
-            //     b_aabb.xMin(), b_aabb.xMax(), b_aabb.yMin(), b_aabb.yMax(), b_aabb.zMin(), b_aabb.zMax(),
-            //     &(b_info.computationalWeight), &(b_info.communicationWeight));
+            pairs::compute_boundary_weights(
+                this->ps,
+                b_aabb.xMin(), b_aabb.xMax(), b_aabb.yMin(), b_aabb.yMax(), b_aabb.zMin(), b_aabb.zMax(),
+                &(b_info.computationalWeight), &(b_info.communicationWeight));
         }
     }
+    
+    int non_globals = ps->getTrackedVariableAsInteger("nlocal") - UniqueID::getNumGlobals();
+    
+    if(sum_block_locals!=non_globals){
+        std::cout << "Warning: " << non_globals - sum_block_locals << " particles in rank " << rank << 
+        " may get lost in the next rebalancing." << std::endl;
+    }
 
+    // Send the weights of my blocks and their children to the neighbors of my blocks
     for(auto& iblock: *forest) {
         auto block = static_cast<walberla::blockforest::Block *>(&iblock);
         auto& block_info = (*info)[block->getId()];
@@ -220,8 +218,8 @@ walberla::Vector3<int> BlockForest::getBlockConfig(int num_processes, int nx, in
 
 int BlockForest::getInitialRefinementLevel(int num_processes) {
     int splitFactor = 8;
-    int blocks = splitFactor;
-    int refinementLevel = 1;
+    int blocks = 1;
+    int refinementLevel = 0;
 
     while(blocks < num_processes) {
         refinementLevel++;
@@ -232,6 +230,9 @@ int BlockForest::getInitialRefinementLevel(int num_processes) {
 }
 
 void BlockForest::setBoundingBox() {
+    for (int i=0; i<6; ++i) subdom[i] = 0.0;
+    if (forest->empty()) return;
+
     auto aabb_union = forest->begin()->getAABB();
     for(auto& iblock: *forest) {
         auto block = static_cast<walberla::blockforest::Block *>(&iblock);
@@ -246,17 +247,6 @@ void BlockForest::setBoundingBox() {
     subdom[5] = aabb_union.zMax();
 }
 
-void BlockForest::rebalance() {
-    if(balance_workload) {
-        this->updateWeights();
-        forest->refresh();
-    }
-
-    this->updateWeights();
-    this->updateNeighborhood();
-    this->setBoundingBox();
-}
-
 void BlockForest::initialize(int *argc, char ***argv) {
     mpiManager = walberla::mpi::MPIManager::instance();
     mpiManager->initializeMPI(argc, argv);
@@ -272,41 +262,62 @@ void BlockForest::initialize(int *argc, char ***argv) {
     auto block_config = balance_workload ? walberla::Vector3<int>(1, 1, 1) :
                                            getBlockConfig(procs, gridsize[0], gridsize[1], gridsize[2]);
 
-    if(rank==0) std::cout << "block_config = " << block_config << std::endl;
-
     auto ref_level = balance_workload ? getInitialRefinementLevel(procs) : 0;
 
-    forest = walberla::blockforest::createBlockForest(
-        domain, block_config, walberla::Vector3<bool>(globalPBC[0], globalPBC[1], globalPBC[2]), procs, ref_level);
+    walberla::Vector3<bool> pbc(globalPBC[0], globalPBC[1], globalPBC[2]);
+
+    forest = walberla::blockforest::createBlockForest(domain, block_config, pbc, procs, ref_level);
 
     this->info = make_shared<walberla::blockforest::InfoCollection>();
 
-    if(balance_workload) {
-        this->initializeWorkloadBalancer();
+    if (rank==0) {
+        std::cout << "Domain: " << domain << std::endl;
+        std::cout << "PBC: " << pbc << std::endl;
+        std::cout << "Block config: " << block_config  << std::endl;
+        std::cout << "Initial refinement level: " << ref_level << std::endl;
+        std::cout << "Dynamic load balancing: " << (balance_workload ? "True" : "False") << std::endl;
     }
 }
 
 void BlockForest::update() {
     if(balance_workload) {
+        if(!forest->loadBalancingFunctionRegistered()){  // refresh() below needs a registered balancing function
+            std::cerr << "Workload balancer is not initialized." << std::endl;
+            exit(-1);
+        }
+        // Recompute block weights, then copy every non-volatile property of the locals to the host.
         this->updateWeights();
-        forest->refresh();
-    }
+        const int nlocal = ps->getTrackedVariableAsInteger("nlocal");
+        for(auto &prop: ps->getProperties()) {
+            if(!prop.isVolatile()) {
+                const size_t ptypesize = get_proptype_size(prop.getType());  // bytes per particle for this property
+                ps->copyPropertyToHost(prop, pairs::WriteAfterRead, static_cast<size_t>(nlocal) * ptypesize);  // size_t math: the byte count can exceed INT_MAX
+            }
+        }
+        // refresh() triggers the rebalancing configured through the setRefresh* calls in initWorkloadBalancer().
+        // PAIRS_DEBUG("Rebalance\n");
+        if (rank==0) std::cout << "Rebalance" << std::endl;
+        forest->refresh();
+    }  // the neighborhood and bounding box below are rebuilt whether or not balancing is enabled
 
-    this->updateWeights();
     this->updateNeighborhood();
     this->setBoundingBox();
 }
 
-void BlockForest::initializeWorkloadBalancer() {
-    std::string algorithm = "morton";
+void BlockForest::initWorkloadBalancer(LoadBalancingAlgorithms algorithm, size_t regridMin, size_t regridMax) {
+    if (rank==0) {
+        std::cout << "Load balancing algorithm: " << getAlgorithmName(algorithm) << std::endl;
+        std::cout << "regridMin = " << regridMin << ", regirdMax = " << regridMax << std::endl;
+    }
+    this->balance_workload = true;  // balance_workload is set to true in case the forest has been initialized externally
     real_t baseWeight = 1.0;
-    real_t metisipc2redist = 1.0;
-    size_t regridMin = 10;
-    size_t regridMax = 100;
     int maxBlocksPerProcess = 100;
-    string metisAlgorithm = "none";
-    string metisWeightsToUse = "none";
-    string metisEdgeSource = "none";
+
+    // Metis-specific params
+    real_t metisipc2redist = 1.0;
+    string metisAlgorithm = "PART_GEOM_KWAY";
+    string metisWeightsToUse = "BOTH_WEIGHTS";
+    string metisEdgeSource = "EDGES_FROM_EDGE_WEIGHTS";
 
     forest->recalculateBlockLevelsInRefresh(true);
     forest->alwaysRebalanceInRefresh(true);
@@ -316,13 +327,12 @@ void BlockForest::initializeWorkloadBalancer() {
     forest->allowMultipleRefreshCycles(false);
     forest->checkForEarlyOutInRefresh(false);
     forest->checkForLateOutInRefresh(false);
+
+    // TODO: Define another functor that makes use of communicationWeight as well
     forest->setRefreshMinTargetLevelDeterminationFunction(
         walberla::blockforest::MinMaxLevelDetermination(info, regridMin, regridMax));
 
-    std::transform(algorithm.begin(), algorithm.end(), algorithm.begin(),
-        [](unsigned char c) { return std::tolower(c); });
-
-    if(algorithm == "morton") {
+    if(algorithm == Morton) {
         forest->setRefreshPhantomBlockDataAssignmentFunction(
             walberla::blockforest::WeightAssignmentFunctor(info, baseWeight));
         forest->setRefreshPhantomBlockDataPackFunction(
@@ -334,7 +344,7 @@ void BlockForest::initializeWorkloadBalancer() {
         prepFunc.setMaxBlocksPerProcess(maxBlocksPerProcess);
         forest->setRefreshPhantomBlockMigrationPreparationFunction(prepFunc);
 
-    } else if(algorithm == "hilbert") {
+    } else if(algorithm == Hilbert) {
         forest->setRefreshPhantomBlockDataAssignmentFunction(
             walberla::blockforest::WeightAssignmentFunctor(info, baseWeight));
         forest->setRefreshPhantomBlockDataPackFunction(
@@ -346,7 +356,7 @@ void BlockForest::initializeWorkloadBalancer() {
         prepFunc.setMaxBlocksPerProcess(maxBlocksPerProcess);
         forest->setRefreshPhantomBlockMigrationPreparationFunction(prepFunc);
 
-    } else if(algorithm == "metis") {
+    } else if(algorithm == Metis) {
         forest->setRefreshPhantomBlockDataAssignmentFunction(
             walberla::blockforest::MetisAssignmentFunctor(info, baseWeight));
         forest->setRefreshPhantomBlockDataPackFunction(
@@ -362,7 +372,7 @@ void BlockForest::initializeWorkloadBalancer() {
         prepFunc.setipc2redist(metisipc2redist);
         forest->setRefreshPhantomBlockMigrationPreparationFunction(prepFunc);
 
-    } else if(algorithm == "diffusive") {
+    } else if(algorithm == Diffusive) {
         forest->setRefreshPhantomBlockDataAssignmentFunction(
             walberla::blockforest::WeightAssignmentFunctor(info, baseWeight));
         forest->setRefreshPhantomBlockDataPackFunction(
@@ -373,6 +383,10 @@ void BlockForest::initializeWorkloadBalancer() {
         auto prepFunc = walberla::blockforest::DynamicDiffusionBalance<walberla::blockforest::WeightAssignmentFunctor::PhantomBlockWeight>(1, 1, false);
         forest->setRefreshPhantomBlockMigrationPreparationFunction(prepFunc);
     }
+    else {
+        std::cerr << "Invalid load balancing algorithm." << std::endl;
+        exit(-1);
+    }
 
     forest->addBlockData(make_shared<walberla::ParticleDataHandling>(ps), "Interface");
 }
diff --git a/runtime/domain/block_forest.hpp b/runtime/domain/block_forest.hpp
index d5d9d54..d814d02 100644
--- a/runtime/domain/block_forest.hpp
+++ b/runtime/domain/block_forest.hpp
@@ -47,7 +47,7 @@ private:
 public:
     BlockForest(
         PairsRuntime *ps_,
-        real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax, bool pbcx, bool pbcy, bool pbcz);
+        real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax, bool pbcx, bool pbcy, bool pbcz, bool balance_workload_);
 
     BlockForest(PairsRuntime *ps_, const std::shared_ptr<walberla::blockforest::BlockForest> &bf);
 
@@ -56,14 +56,17 @@ public:
     }
 
     void initialize(int *argc, char ***argv);
+    void initWorkloadBalancer(LoadBalancingAlgorithms algorithm, size_t regridMin, size_t regridMax);
+
     void update();
     void finalize();
     int getWorldSize() const { return world_size; }
     int getRank() const { return rank; }
     int getNumberOfNeighborRanks() { return this->nranks; }
     int getNumberOfNeighborAABBs() { return this->total_aabbs; }
+    double getSubdomMin(int dim) const { return subdom[2*dim + 0];}  // even slot of the (min, max) pair stored for dimension dim; dim is not range-checked
+    double getSubdomMax(int dim) const { return subdom[2*dim + 1];}  // odd slot of the (min, max) pair stored for dimension dim; dim is not range-checked
 
-    void initializeWorkloadBalancer();
     void updateNeighborhood();
     void updateWeights();
     walberla::math::Vector3<int> getBlockConfig(int num_processes, int nx, int ny, int nz);
diff --git a/runtime/domain/domain_partitioning.hpp b/runtime/domain/domain_partitioning.hpp
index 48d9646..3dfdaae 100644
--- a/runtime/domain/domain_partitioning.hpp
+++ b/runtime/domain/domain_partitioning.hpp
@@ -39,7 +39,10 @@ public:
 
     double getMin(int dim) const { return grid_min[dim]; }
     double getMax(int dim) const { return grid_max[dim]; }
+    virtual double getSubdomMin(int dim) const = 0;
+    virtual double getSubdomMax(int dim) const = 0;
     virtual void initialize(int *argc, char ***argv) = 0;
+    virtual void initWorkloadBalancer(LoadBalancingAlgorithms algorithm, size_t regridMin, size_t regridMax) = 0;
     virtual void update() = 0;
     virtual int getWorldSize() const = 0;
     virtual int getRank() const = 0;
diff --git a/runtime/domain/regular_6d_stencil.cpp b/runtime/domain/regular_6d_stencil.cpp
index 6da8b4c..96ea998 100644
--- a/runtime/domain/regular_6d_stencil.cpp
+++ b/runtime/domain/regular_6d_stencil.cpp
@@ -89,6 +89,8 @@ void Regular6DStencil::initialize(int *argc, char ***argv) {
     this->setBoundingBox();
 }
 
+void Regular6DStencil::initWorkloadBalancer(LoadBalancingAlgorithms algorithm, size_t regridMin, size_t regridMax) {}  // no-op: all three arguments are ignored by this partitioner
+
 void Regular6DStencil::update() {}
 
 void Regular6DStencil::finalize() {
diff --git a/runtime/domain/regular_6d_stencil.hpp b/runtime/domain/regular_6d_stencil.hpp
index 7ed1867..b4a9e5c 100644
--- a/runtime/domain/regular_6d_stencil.hpp
+++ b/runtime/domain/regular_6d_stencil.hpp
@@ -51,6 +51,7 @@ public:
     void setConfig();
     void setBoundingBox();
     void initialize(int *argc, char ***argv);
+    void initWorkloadBalancer(LoadBalancingAlgorithms algorithm, size_t regridMin, size_t regridMax);
     void update();
     void finalize();
 
@@ -58,6 +59,8 @@ public:
     int getRank() const { return rank; }
     int getNumberOfNeighborRanks() { return 6; }
     int getNumberOfNeighborAABBs() { return 6; }
+    double getSubdomMin(int dim) const { return subdom_min[dim];}  // returns the subdom_min entry for dimension dim (no range check)
+    double getSubdomMax(int dim) const { return subdom_max[dim];}  // returns the subdom_max entry for dimension dim (no range check)
 
     int isWithinSubdomain(real_t x, real_t y, real_t z);
     void copyRuntimeArray(const std::string& name, void *dest, const int size);
diff --git a/runtime/pairs.cpp b/runtime/pairs.cpp
index 40c1381..4897a35 100644
--- a/runtime/pairs.cpp
+++ b/runtime/pairs.cpp
@@ -14,7 +14,9 @@ namespace pairs {
 
 void PairsRuntime::initDomain(
     int *argc, char ***argv,
-    real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax, bool pbcx, bool pbcy, bool pbcz) {
+    real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax, 
+    bool pbcx, bool pbcy, bool pbcz, 
+    bool balance_workload) {
 
     int mpi_initialized=0;
     MPI_Initialized(&mpi_initialized);
@@ -38,7 +40,7 @@ void PairsRuntime::initDomain(
     
 #ifdef USE_WALBERLA
     else if(dom_part_type == BlockForestPartitioning) {
-        dom_part = new BlockForest(this, xmin, xmax, ymin, ymax, zmin, zmax, pbcx, pbcy, pbcz);
+        dom_part = new BlockForest(this, xmin, xmax, ymin, ymax, zmin, zmax, pbcx, pbcy, pbcz, balance_workload);
     } 
 #endif
 
diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp
index cf1e65e..54abd6d 100644
--- a/runtime/pairs.hpp
+++ b/runtime/pairs.hpp
@@ -222,6 +222,10 @@ public:
         copyPropertyToDevice(getProperty(id), action, size);
     }
 
+    void copyPropertyToDevice(Property &prop, action_t action) {  // convenience overload without an explicit size
+        copyPropertyToDevice(prop, action, prop.getTotalSize());  // forwards to the sized overload using the property's total size
+    }
+    
     void copyPropertyToDevice(Property &prop, action_t action, size_t size);
 
     void copyPropertyToHost(property_t id, action_t action) {
@@ -243,6 +247,10 @@ public:
         return prop_flags;
     }
 
+    DeviceFlags* getArrayFlags(){  // exposes the runtime's array_flags member
+        return array_flags;  // raw pointer to the member, not a copy; ownership is not transferred by this call
+    }
+    
     // Contact properties
     ContactProperty &getContactProperty(property_t id);
     ContactProperty &getContactPropertyByName(std::string name);
@@ -313,7 +321,7 @@ public:
     void initDomain(
         int *argc, char ***argv,
         real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax,
-        bool pbcx = 0, bool pbcy = 0, bool pbcz = 0);
+        bool pbcx = 0, bool pbcy = 0, bool pbcz = 0, bool balance_workload = 0);
 
     template<typename Domain_T>
     void useDomain(const std::shared_ptr<Domain_T> &domain_ptr);
diff --git a/runtime/pairs_common.hpp b/runtime/pairs_common.hpp
index ba2c56c..7423742 100644
--- a/runtime/pairs_common.hpp
+++ b/runtime/pairs_common.hpp
@@ -35,7 +35,7 @@ typedef double real_t;
 //typedef float real_t;
 //#endif
 
-typedef unsigned long long int id_t;
+typedef uint64_t id_t;
 typedef int array_t;
 typedef int property_t;
 typedef int layout_t;
@@ -51,6 +51,18 @@ enum PropertyType {
     Prop_Quaternion
 };
 
+// Size in bytes of one element of a property of the given type: integer, uint64 and real
+// properties hold a single value; vectors, matrices and quaternions hold 3, 9 and 4
+// real_t values respectively. Any other type yields 0.
+constexpr size_t get_proptype_size(PropertyType type){
+    return (type == pairs::Prop_Integer)    ? sizeof(int) :
+           (type == pairs::Prop_UInt64)     ? sizeof(uint64_t) :
+           (type == pairs::Prop_Real)       ? sizeof(real_t) :
+           (type == pairs::Prop_Vector)     ? 3*sizeof(real_t) :
+           (type == pairs::Prop_Matrix)     ? 9*sizeof(real_t) :
+           (type == pairs::Prop_Quaternion) ? 4*sizeof(real_t) : 0;
+}
+
 enum DataLayout {
     Invalid = -1,
     AoS = 0,
@@ -79,6 +91,23 @@ enum DomainPartitioners {
     BlockForestPartitioning = 2
 };
 
+enum LoadBalancingAlgorithms {  // selector passed to initWorkloadBalancer(); getAlgorithmName() gives the printable names
+    Morton = 0,
+    Hilbert = 1,
+    Metis = 2,
+    Diffusive = 3  // last valid value; anything else is reported as "Invalid" by getAlgorithmName()
+};
+
+// Printable name of a load balancing algorithm.
+// Values outside the enumeration map to "Invalid".
+constexpr const char* getAlgorithmName(LoadBalancingAlgorithms alg) {
+    return (alg == Morton)    ? "Morton" :
+           (alg == Hilbert)   ? "Hilbert" :
+           (alg == Metis)     ? "Metis" :
+           (alg == Diffusive) ? "Diffusive" :
+                                "Invalid";
+}
+
 #ifdef DEBUG
 #   include <assert.h>
 #   define PAIRS_DEBUG(...)     {                                                   \
diff --git a/runtime/property.hpp b/runtime/property.hpp
index 301a46b..fd2c5e4 100644
--- a/runtime/property.hpp
+++ b/runtime/property.hpp
@@ -43,7 +43,7 @@ public:
     int isVolatile() const { return vol != 0; }
     size_t getPrimitiveTypeSize() const {
         return  (type == Prop_Integer) ? sizeof(int) :
-                (type == Prop_UInt64) ? sizeof(unsigned long long int) :
+                (type == Prop_UInt64) ? sizeof(uint64_t) :
                 (type == Prop_Real) ? sizeof(real_t) :
                 (type == Prop_Vector) ? sizeof(real_t) :
                 (type == Prop_Matrix) ? sizeof(real_t) :
@@ -58,7 +58,7 @@ public:
 
 class UInt64Property : public Property {
 public:
-    inline unsigned long long int &operator()(int i) { return static_cast<unsigned long long int *>(h_ptr)[i]; }
+    inline uint64_t &operator()(int i) { return static_cast<uint64_t *>(h_ptr)[i]; }  // mutable reference to element i of h_ptr viewed as uint64_t[]; no bounds check
 };
 
 class FloatProperty : public Property {
diff --git a/runtime/unique_id.hpp b/runtime/unique_id.hpp
index 2905355..cfc95a7 100644
--- a/runtime/unique_id.hpp
+++ b/runtime/unique_id.hpp
@@ -8,6 +8,7 @@ class UniqueID{
 public:
     inline static id_t create(PairsRuntime *pr);
     inline static id_t createGlobal(PairsRuntime *pr);
+    inline static id_t getNumGlobals();
 
 private:
     static const id_t capacity = 1000000000;   // max number of particles per rank
@@ -16,6 +17,10 @@ private:
 
 };
 
+inline id_t UniqueID::getNumGlobals(){  // number of global IDs handed out so far
+    return globalCounter - 1;  // NOTE(review): presumably globalCounter starts at 1 and is bumped by createGlobal() -- confirm
+}
+
 inline id_t UniqueID::create(PairsRuntime *pr){
     id_t rank = static_cast<id_t>(pr->getDomainPartitioner()->getRank());
     id_t id = rank*capacity + counter;
diff --git a/runtime/vtk.cpp b/runtime/vtk.cpp
index 43502b0..db805a3 100644
--- a/runtime/vtk.cpp
+++ b/runtime/vtk.cpp
@@ -6,6 +6,123 @@
 
 namespace pairs {
 
+// Writes the axis-aligned box [xmin,xmax] x [ymin,ymax] x [zmin,zmax] as a legacy ASCII
+// VTK polydata file (8 corner points, 6 quad faces).
+// The file is named "<filename>_<num>.vtk", or "<filename>_<num>r<rank>.vtk" when the
+// world size is greater than one. Exits the program if the file cannot be opened.
+void vtk_write_aabb(PairsRuntime *ps, const char *filename, int num,
+    double xmin, double xmax, 
+    double ymin, double ymax, 
+    double zmin, double zmax){
+
+    const int prec = 8;
+    std::ostringstream filename_oss;
+
+    filename_oss << filename << "_" << num;
+    if(ps->getDomainPartitioner()->getWorldSize() > 1) {
+        filename_oss << "r" << ps->getDomainPartitioner()->getRank() ;
+    }
+
+    filename_oss <<".vtk";
+    std::ofstream out_file(filename_oss.str());
+
+    // The box comes from the arguments only; the partitioner is used just for rank/world size.
+    out_file << std::fixed << std::setprecision(prec);
+    if(out_file.is_open()) {
+        out_file << "# vtk DataFile Version 2.0\n";
+        out_file << "Subdomains\n";
+        out_file << "ASCII\n";
+        out_file << "DATASET POLYDATA\n";
+        out_file << "POINTS 8 double\n";
+
+        // Corner points: the z = zmin face (indices 0-3) followed by the z = zmax face (4-7).
+        out_file << xmin << " " << ymin << " " << zmin << "\n";
+        out_file << xmax << " " << ymin << " " << zmin << "\n";
+        out_file << xmax << " " << ymax << " " << zmin << "\n";
+        out_file << xmin << " " << ymax << " " << zmin << "\n";
+        out_file << xmin << " " << ymin << " " << zmax << "\n";
+        out_file << xmax << " " << ymin << " " << zmax << "\n";
+        out_file << xmax << " " << ymax << " " << zmax << "\n";
+        out_file << xmin << " " << ymax << " " << zmax << "\n";
+
+        // 6 faces x (1 vertex count + 4 corner indices) = 30 integers.
+        out_file << "POLYGONS 6 30\n";
+
+        out_file << "4 0 1 2 3 \n";
+        out_file << "4 4 5 6 7 \n";
+        out_file << "4 0 1 5 4 \n";
+        out_file << "4 3 2 6 7 \n";
+        out_file << "4 0 4 7 3 \n";
+        out_file << "4 1 2 6 5 \n";
+
+        out_file << "\n\n";
+        out_file.close();
+    }
+    else {
+        std::cerr << "vtk_write_aabb: Failed to open " << filename_oss.str() << std::endl;
+        exit(-1);
+    }
+
+}
+
+// Writes this rank's subdomain bounding box (getSubdomMin/getSubdomMax of the domain
+// partitioner) as a legacy ASCII VTK polydata file with 8 corner points and 6 quad faces.
+// Nothing is written when frequency != 0 and timestep is not a multiple of frequency.
+// The file is "<filename>_<timestep>.vtk", or "<filename>_r<rank>_<timestep>.vtk" on >1 ranks.
+void vtk_write_subdom(PairsRuntime *ps, const char *filename, int timestep, int frequency){
+    const int prec = 8;
+    std::ostringstream filename_oss;
+
+    if(frequency != 0 && timestep % frequency != 0) {
+        return;
+    }
+
+    filename_oss << filename << "_";
+    if(ps->getDomainPartitioner()->getWorldSize() > 1) {
+        filename_oss << "r" << ps->getDomainPartitioner()->getRank() << "_";
+    }
+
+    filename_oss << timestep << ".vtk";
+    std::ofstream out_file(filename_oss.str());
+
+    double aabb[3][2];  // aabb[d] = {min, max} of the subdomain along dimension d
+    for (int d=0; d<3; ++d){
+        aabb[d][0] = ps->getDomainPartitioner()->getSubdomMin(d);
+        aabb[d][1] = ps->getDomainPartitioner()->getSubdomMax(d);
+    }
+
+    out_file << std::fixed << std::setprecision(prec);
+    if(out_file.is_open()) {
+        out_file << "# vtk DataFile Version 2.0\n";
+        out_file << "Subdomains\n";
+        out_file << "ASCII\n";
+        out_file << "DATASET POLYDATA\n";
+        out_file << "POINTS 8 double\n";
+
+        out_file << aabb[0][0] << " " << aabb[1][0] << " " << aabb[2][0] << "\n";
+        out_file << aabb[0][1] << " " << aabb[1][0] << " " << aabb[2][0] << "\n";
+        out_file << aabb[0][1] << " " << aabb[1][1] << " " << aabb[2][0] << "\n";
+        out_file << aabb[0][0] << " " << aabb[1][1] << " " << aabb[2][0] << "\n";
+        out_file << aabb[0][0] << " " << aabb[1][0] << " " << aabb[2][1] << "\n";
+        out_file << aabb[0][1] << " " << aabb[1][0] << " " << aabb[2][1] << "\n";
+        out_file << aabb[0][1] << " " << aabb[1][1] << " " << aabb[2][1] << "\n";
+        out_file << aabb[0][0] << " " << aabb[1][1] << " " << aabb[2][1] << "\n";
+
+        out_file << "POLYGONS 6 30\n";
+        out_file << "4 0 1 2 3 \n";
+        out_file << "4 4 5 6 7 \n";
+        out_file << "4 0 1 5 4 \n";
+        out_file << "4 3 2 6 7 \n";
+        out_file << "4 0 4 7 3 \n";
+        out_file << "4 1 2 6 5 \n";
+        out_file << "\n\n";
+        out_file.close();
+    } else {
+        std::cerr << "vtk_write_subdom: Failed to open " << filename_oss.str() << std::endl;
+        exit(-1);
+    }
+}
+
 void vtk_write_data(
     PairsRuntime *ps, const char *filename, int start, int end, int timestep, int frequency) {
 
@@ -13,6 +130,7 @@ void vtk_write_data(
     auto masses = ps->getAsFloatProperty(ps->getPropertyByName("mass"));
     auto positions = ps->getAsVectorProperty(ps->getPropertyByName("position"));
     auto flags = ps->getAsIntegerProperty(ps->getPropertyByName("flags"));
+    auto radius = ps->getAsFloatProperty(ps->getPropertyByName("radius"));
     const int prec = 8;
     int n = end - start;
     std::ostringstream filename_oss;
@@ -32,6 +150,7 @@ void vtk_write_data(
     ps->copyPropertyToHost(masses, ReadOnly);
     ps->copyPropertyToHost(positions, ReadOnly);
     ps->copyPropertyToHost(flags, ReadOnly);
+    ps->copyPropertyToHost(radius, ReadOnly);
 
     for(int i = start; i < end; i++) {
         if(flags(i) & flags::INFINITE) {
@@ -43,7 +162,7 @@ void vtk_write_data(
         out_file << "# vtk DataFile Version 2.0\n";
         out_file << "Particle data\n";
         out_file << "ASCII\n";
-        out_file << "DATASET UNSTRUCTURED_GRID\n";
+        out_file << "DATASET POLYDATA\n";
         out_file << "POINTS " << n << " double\n";
 
         for(int i = start; i < end; i++) {
@@ -55,26 +174,21 @@ void vtk_write_data(
         }
 
         out_file << "\n\n";
-        out_file << "CELLS " << n << " " << (n * 2) << "\n";
-        for(int i = 0; i < n; i++) {
-            out_file << "1 " << i << "\n";
-        }
-
-        out_file << "\n\n";
-        out_file << "CELL_TYPES " << n << "\n";
+        out_file << "POINT_DATA " << n << "\n";
+        out_file << "SCALARS mass double 1\n";
+        out_file << "LOOKUP_TABLE default\n";
         for(int i = start; i < end; i++) {
             if(!(flags(i) & flags::INFINITE)) {
-                out_file << "1\n";
+                out_file << std::fixed << std::setprecision(prec) << masses(i) << "\n";
             }
         }
 
         out_file << "\n\n";
-        out_file << "POINT_DATA " << n << "\n";
-        out_file << "SCALARS mass double\n";
+        out_file << "SCALARS radius double 1\n";
         out_file << "LOOKUP_TABLE default\n";
         for(int i = start; i < end; i++) {
             if(!(flags(i) & flags::INFINITE)) {
-                out_file << std::fixed << std::setprecision(prec) << masses(i) << "\n";
+                out_file << std::fixed << std::setprecision(prec) << radius(i) << "\n";
             }
         }
 
diff --git a/runtime/vtk.hpp b/runtime/vtk.hpp
index 2bc766c..dcd97c0 100644
--- a/runtime/vtk.hpp
+++ b/runtime/vtk.hpp
@@ -4,7 +4,20 @@
 
 namespace pairs {
 
+// NOTE(review): presumably writes the box [xmin,xmax] x [ymin,ymax] x [zmin,zmax] as VTK
+// output named after `filename`; the meaning of `num` is not visible here -- confirm.
+void vtk_write_aabb(PairsRuntime *ps, const char *filename, int num,
+    double xmin, double xmax, 
+    double ymin, double ymax, 
+    double zmin, double zmax);
+
+// `frequency` defaults to 1 when omitted.
+// NOTE(review): presumably writes the subdomain bounds of `ps` for `timestep` -- confirm.
+void vtk_write_subdom(PairsRuntime *ps, const char *filename, int timestep, int frequency=1);
+
+// Writes the particles in [start, end) to a VTK file (vtk.cpp skips those flagged
+// INFINITE); `frequency` defaults to 1 when omitted.
 void vtk_write_data(
-    PairsRuntime *ps, const char *filename, int start, int end, int timestep, int frequency);
+    PairsRuntime *ps, const char *filename, int start, int end, int timestep, int frequency=1);
 
 }
diff --git a/src/pairs/__init__.py b/src/pairs/__init__.py
index 24ccbc8..6525a98 100644
--- a/src/pairs/__init__.py
+++ b/src/pairs/__init__.py
@@ -2,6 +2,7 @@ from pairs.ir.types import Types
 from pairs.code_gen.cgen import CGen
 from pairs.code_gen.target import Target
 from pairs.sim.domain_partitioners import DomainPartitioners
+from pairs.sim.load_balancing_algorithms import LoadBalancingAlgorithms
 from pairs.sim.shapes import Shapes
 from pairs.sim.simulation import Simulation
 
@@ -69,3 +70,19 @@ def regular_domain_partitioner_xy():
 
 def block_forest():
     return DomainPartitioners.BlockForest
+
+def morton():
+    """Return the `LoadBalancingAlgorithms.Morton` selector."""
+    return LoadBalancingAlgorithms.Morton
+
+def hilbert():
+    """Return the `LoadBalancingAlgorithms.Hilbert` selector."""
+    return LoadBalancingAlgorithms.Hilbert
+
+def metis():
+    """Return the `LoadBalancingAlgorithms.Metis` selector."""
+    return LoadBalancingAlgorithms.Metis
+
+def diffusive():
+    """Return the `LoadBalancingAlgorithms.Diffusive` selector."""
+    return LoadBalancingAlgorithms.Diffusive
\ No newline at end of file
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 82b082f..d4e0985 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -376,23 +376,14 @@ class CGen:
         self.print("public:")
         self.print.add_indent(4)
 
+        self.print("PairsRuntime* getPairsRuntime() {")
+        self.print("    return pairs_runtime;")
+        self.print("}")
+
         # Only interface modules are generated in the PairsSimulation class
         for module in self.sim.interface_modules():
             self.generate_module(module)
 
-        # Generate a 'use_domain' module only if domain is not predefined in the input script
-        if not self.sim.create_domain_at_initialization:
-            self.print("template<typename Domain_T>")
-            self.print("void use_domain(const std::shared_ptr<Domain_T> &domain_ptr) {")
-            self.print("    pairs_runtime->useDomain(domain_ptr);")
-            self.print("}")
-            self.print("")
-
-        self.print("void vtk_write(const char* filename, int start, int end, int timestep, int frequency) {")
-        self.print("    pairs::vtk_write_data(pairs_runtime, filename, start, end, timestep, frequency);")
-        self.print("}")
-        self.print("")
-
         self.print.add_indent(-4)
         self.print("};")
 
diff --git a/src/pairs/code_gen/interface.py b/src/pairs/code_gen/interface.py
index 6255ff4..6ed3f7f 100644
--- a/src/pairs/code_gen/interface.py
+++ b/src/pairs/code_gen/interface.py
@@ -3,7 +3,7 @@ from pairs.ir.functions import Call_Void, Call, Call_Int
 from pairs.ir.parameters import Parameter
 from pairs.ir.ret import Return
 from pairs.ir.scalars import ScalarOp
-from pairs.sim.domain import UpdateDomain, SetDomain
+from pairs.sim.domain import UpdateDomain
 from pairs.sim.cell_lists import BuildCellListsStencil
 from pairs.sim.comm import Synchronize, Borders, Exchange, ReverseComm
 from pairs.ir.types import Types
@@ -28,12 +28,8 @@ class InterfaceModules:
 
     def create_all(self):
         self.initialize()
-
-        # Generate a 'set_domain' module only if domain is not pre-set in the input script
-        if not self.sim.create_domain_at_initialization:
-            self.set_domain()
-
         self.setup_sim()
+        self.update_domain()
         self.update_cells(self.sim.reneighbor_frequency) 
         self.communicate(self.sim.reneighbor_frequency)
         self.reverse_comm() 
@@ -51,9 +47,6 @@ class InterfaceModules:
         self.nlocal()
         self.nghost()
         self.size()
-        self.create_sphere()
-        self.create_halfspace()
-        self.dem_sc_grid()
         self.end()      
 
     @pairs_interface_block
@@ -84,12 +77,6 @@ class InterfaceModules:
             self.sim.grid = MutableGrid(self.sim, self.sim.dims)
             self.sim.add_statement(inits)
 
-    @pairs_interface_block
-    def set_domain(self):
-        assert isinstance(self.sim.grid, MutableGrid)
-        self.sim.module_name('set_domain')
-        self.sim.add_statement(SetDomain(self.sim))
-
     @pairs_interface_block
     def setup_sim(self):
         self.sim.module_name('setup_sim')
@@ -102,9 +89,21 @@ class InterfaceModules:
             Assign(self.sim, self.sim.cell_lists.cutoff_radius, Parameter(self.sim, 'cutoff_radius', Types.Real))
 
         self.sim.add_statement(self.sim.setup_particles)
-        self.sim.add_statement(UpdateDomain(self.sim))
+        # This update assumes all particles have been created exactly in the rank that contains them 
+        self.sim.add_statement(UpdateDomain(self.sim))  
         self.sim.add_statement(BuildCellListsStencil(self.sim, self.sim.cell_lists))
-    
+        
+    @pairs_interface_block
+    def update_domain(self):
+        self.sim.module_name('update_domain')
+        self.sim.add_statement(Exchange(self.sim._comm))    # Local particles must be contained in their owners before domain update
+        self.sim.add_statement(UpdateDomain(self.sim))
+        # Exchange is not needed after update since all locals are contained in their owners
+        self.sim.add_statement(Borders(self.sim._comm))     # Ghosts must be recreated after update
+        self.sim.add_statement(ResetVolatileProperties(self.sim))   # Reset volatile includes the new locals
+        self.sim.add_statement(BuildCellListsStencil(self.sim, self.sim.cell_lists))    # Rebuild stencil since subdom sizes have changed
+        self.sim.add_statement(self.sim.update_cells_procedures)    # Rebuild cell lists (and any neighbor lists) for the updated domain
+        
     @pairs_interface_block
     def reset_volatiles(self):
         self.sim.module_name('reset_volatiles')
@@ -118,19 +117,8 @@ class InterfaceModules:
             ScalarOp.cmp((timestep + 1) % reneighbor_frequency, 0),
             ScalarOp.cmp(timestep, 0)
             ))
-        
-        subroutines = [BuildCellLists(self.sim, self.sim.cell_lists),
-                  PartitionCellLists(self.sim, self.sim.cell_lists)]
-        
-        # Add routine to build neighbor-lists per cell
-        if self.sim._store_neighbors_per_cell:
-            subroutines.append(BuildCellNeighborLists(self.sim, self.sim.cell_lists))
-
-        # Add routine to build neighbor-lists per particle (standard Verlet Lists)
-        if self.sim.neighbor_lists:
-            subroutines.append(BuildNeighborLists(self.sim, self.sim.neighbor_lists))
 
-        self.sim.add_statement(Filter(self.sim, cond, Block.from_list(self.sim, subroutines)))
+        self.sim.add_statement(Filter(self.sim, cond, self.sim.update_cells_procedures))
 
     @pairs_interface_block
     def communicate(self, reneighbor_frequency=1):
@@ -142,11 +130,15 @@ class InterfaceModules:
             ))
         
         exchange = Filter(self.sim, cond, Exchange(self.sim._comm))
-        border_sync = Branch(self.sim, cond, blk_if = Borders(self.sim._comm), 
+        border_sync = Branch(self.sim, cond, 
+                             blk_if = Borders(self.sim._comm), 
                              blk_else = Synchronize(self.sim._comm))
         
         self.sim.add_statement(exchange)
         self.sim.add_statement(border_sync)
+        
+        # TODO: Maybe remove this from here, but volatiles must always be reset after exchange
+        self.sim.add_statement(Filter(self.sim, cond, Block(self.sim, ResetVolatileProperties(self.sim))))   
 
     @pairs_interface_block
     def reverse_comm(self):
diff --git a/src/pairs/ir/types.py b/src/pairs/ir/types.py
index ab27939..c2548cd 100644
--- a/src/pairs/ir/types.py
+++ b/src/pairs/ir/types.py
@@ -24,8 +24,8 @@ class Types:
             else 'float' if t == Types.Float
             else 'double' if t == Types.Double
             else 'int' if t == Types.Int32
-            else 'long long int' if t == Types.Int64
-            else 'unsigned long long int' if t == Types.UInt64
+            else 'int64_t' if t == Types.Int64
+            else 'uint64_t' if t == Types.UInt64
             else 'bool' if t == Types.Boolean
             else 'void' if t == Types.Void
             else '<invalid type>'
@@ -38,8 +38,8 @@ class Types:
             else 'float' if t == Types.Float
             else 'double' if t == Types.Double
             else 'int' if t == Types.Int32
-            else 'long long int' if t == Types.Int64
-            else 'unsigned long long int' if t == Types.UInt64
+            else 'int64_t' if t == Types.Int64
+            else 'uint64_t' if t == Types.UInt64
             else 'bool' if t == Types.Boolean
             else 'void' if t == Types.Void
             else '<invalid type>'
diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index 6ec0179..00622af 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -78,6 +78,7 @@ class Borders(Lowerable):
         # exists in any force calculation kernel)
         # We ignore normal because there should be no ghost half-spaces
         prop_names = [
+            'flags',
             'uid',
             'type',
             'mass',
diff --git a/src/pairs/sim/domain.py b/src/pairs/sim/domain.py
index 9a85792..2560f1a 100644
--- a/src/pairs/sim/domain.py
+++ b/src/pairs/sim/domain.py
@@ -1,10 +1,6 @@
 from pairs.ir.block import pairs_inline
-from pairs.ir.parameters import Parameter
-from pairs.ir.types import Types
-from pairs.ir.assign import Assign
 from pairs.sim.lowerable import Lowerable
 
-
 class InitializeDomain(Lowerable):
     def __init__(self, sim):
         super().__init__(sim)
@@ -13,22 +9,6 @@ class InitializeDomain(Lowerable):
     def lower(self):
         self.sim.domain_partitioning().initialize()
 
-class SetDomain(Lowerable):
-    def __init__(self, sim):
-        super().__init__(sim)
-
-    @pairs_inline
-    def lower(self):
-        for d in range(self.sim.ndims()):
-            dmin = Parameter(self.sim, f'd{d}_min', Types.Real)
-            Assign(self.sim, self.sim.grid.min(d), dmin)
-
-        for d in range(self.sim.ndims()):
-            dmax = Parameter(self.sim, f'd{d}_max', Types.Real)
-            Assign(self.sim, self.sim.grid.max(d), dmax)
-
-        self.sim.domain_partitioning().initialize()
-
 class UpdateDomain(Lowerable):
     def __init__(self, sim):
         super().__init__(sim)
diff --git a/src/pairs/sim/domain_partitioning.py b/src/pairs/sim/domain_partitioning.py
index de38396..ef0502d 100644
--- a/src/pairs/sim/domain_partitioning.py
+++ b/src/pairs/sim/domain_partitioning.py
@@ -8,8 +8,11 @@ from pairs.ir.types import Types
 from pairs.sim.flags import Flags
 from pairs.ir.lit import Lit
 from pairs.sim.grid import MutableGrid
-
-
+from pairs.ir.device import CopyArray
+from pairs.ir.contexts import Contexts
+from pairs.ir.actions import Actions
+from pairs.sim.load_balancing_algorithms import LoadBalancingAlgorithms
+from pairs.ir.print import PrintCode
 class DimensionRanges:
     def __init__(self, sim):
         self.sim                = sim
@@ -96,6 +99,9 @@ class DimensionRanges:
 class BlockForest:
     def __init__(self, sim):
         self.sim                = sim
+        self.load_balancer      = None
+        self.regrid_min         = None
+        self.regrid_max         = None
         self.reduce_step        = sim.add_var('reduce_step', Types.Int32)   # this var is treated as a tmp (workaround for gpu)
         self.reduce_step.force_read = True
         self.rank               = sim.add_var('rank', Types.Int32)
@@ -136,7 +142,17 @@ class BlockForest:
 
     def initialize(self):
         grid_array = [(self.sim.grid.min(d), self.sim.grid.max(d)) for d in range(self.sim.ndims())]
-        Call_Void(self.sim, "pairs_runtime->initDomain", [param for delim in grid_array for param in delim] + self.sim._pbc)
+
+        Call_Void(self.sim, "pairs_runtime->initDomain", 
+                  [param for delim in grid_array for param in delim] + 
+                  self.sim._pbc + ([True] if self.load_balancer is not None else []))
+        
+        if self.load_balancer is not None:
+            PrintCode(self.sim, "pairs_runtime->getDomainPartitioner()->initWorkloadBalancer"
+                      f"({LoadBalancingAlgorithms.c_keyword(self.load_balancer)}, {self.regrid_min}, {self.regrid_max});")
+
+            # Call_Void(self.sim, "pairs_runtime->getDomainPartitioner()->initWorkloadBalancer", 
+            #           [self.load_balancer, self.regrid_min, self.regrid_max])
 
     def update(self):
         Call_Void(self.sim, "pairs_runtime->updateDomain", [])
@@ -155,6 +171,12 @@ class BlockForest:
             for _ in Filter(self.sim, self.aabb_capacity < self.ntotal_aabbs):
                 Assign(self.sim, self.aabb_capacity, self.ntotal_aabbs + 20)
                 self.aabbs.realloc()
+            
+            CopyArray(self.sim, self.ranks, Contexts.Host, Actions.WriteOnly, self.nranks)
+            CopyArray(self.sim, self.naabbs, Contexts.Host, Actions.WriteOnly, self.nranks)
+            CopyArray(self.sim, self.aabb_offsets, Contexts.Host, Actions.WriteOnly, self.nranks)
+            CopyArray(self.sim, self.aabbs, Contexts.Host, Actions.WriteOnly, self.ntotal_aabbs * 6)
+            CopyArray(self.sim, self.subdom, Contexts.Host, Actions.WriteOnly)
 
             Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['ranks', self.ranks, self.nranks])
             Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['naabbs', self.naabbs, self.nranks])
diff --git a/src/pairs/sim/load_balancing_algorithms.py b/src/pairs/sim/load_balancing_algorithms.py
new file mode 100644
index 0000000..165d151
--- /dev/null
+++ b/src/pairs/sim/load_balancing_algorithms.py
@@ -0,0 +1,18 @@
+class LoadBalancingAlgorithms:
+    """Identifiers for the workload-balancing algorithms selectable for BlockForest."""
+    Morton = 0
+    Hilbert = 1
+    Metis = 2
+    Diffusive = 3
+
+    @staticmethod
+    def c_keyword(algorithm):
+        """Return the keyword printed into the generated C++ code for `algorithm`.
+
+        Values that match none of the identifiers above map to "Invalid".
+        """
+        return "Hilbert"        if algorithm == LoadBalancingAlgorithms.Hilbert else \
+               "Morton"         if algorithm == LoadBalancingAlgorithms.Morton else \
+               "Diffusive"      if algorithm == LoadBalancingAlgorithms.Diffusive else \
+               "Metis"          if algorithm == LoadBalancingAlgorithms.Metis else \
+               "Invalid"
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 2c0e194..f7360b4 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -17,9 +17,10 @@ from pairs.sim.comm import Comm, Synchronize, Borders, Exchange, ReverseComm
 from pairs.sim.contact_history import ContactHistory, BuildContactHistory, ClearUnusedContactHistory, ResetContactHistoryUsageStatus
 from pairs.sim.copper_fcc_lattice import CopperFCCLattice
 from pairs.sim.dem_sc_grid import DEMSCGrid
-from pairs.sim.domain import InitializeDomain, UpdateDomain, SetDomain
+from pairs.sim.domain import InitializeDomain, UpdateDomain
 from pairs.sim.domain_partitioners import DomainPartitioners
 from pairs.sim.domain_partitioning import BlockForest, DimensionRanges
+from pairs.sim.load_balancing_algorithms import LoadBalancingAlgorithms
 from pairs.sim.features import AllocateFeatureProperties
 from pairs.sim.grid import Grid2D, Grid3D, MutableGrid
 from pairs.sim.instrumentation import RegisterMarkers, RegisterTimers
@@ -83,6 +84,7 @@ class Simulation:
         self.cell_lists = None
         self._store_neighbors_per_cell = False
         self.neighbor_lists = None
+        self.update_cells_procedures = Block(self, [])
 
         # Context information used to partially build the program AST
         self.scope = []
@@ -131,6 +133,7 @@ class Simulation:
         self.dims = dims                        # Number of dimensions
         self.ntimesteps = timesteps             # Number of time-steps
         self.reneighbor_frequency = 1           # Re-neighbor frequency
+        self.rebalance_frequency = 0            # Re-balance frequency for dynamic load balancing
         self._target = None                     # Hardware target info
         self._pbc = [True for _ in range(dims)] # PBC flags for each dimension
         self._shapes = shapes                   # List of shapes used in the simulation
@@ -151,6 +154,14 @@ class Simulation:
 
         else:
             raise Exception("Invalid domain partitioner.")
+        
+    def set_workload_balancer(self, algorithm=LoadBalancingAlgorithms.Morton,
+                              regrid_min=100, regrid_max=1000, rebalance_frequency=0):
+        """Store the load-balancing algorithm, regrid bounds and rebalance frequency (BlockForest only)."""
+        assert self._partitioner == DomainPartitioners.BlockForest, "Load balancing is only supported by BlockForest."
+        self._dom_part.load_balancer = algorithm
+        self._dom_part.regrid_min, self._dom_part.regrid_max = regrid_min, regrid_max
+        self.rebalance_frequency = rebalance_frequency
 
     def partitioner(self):
         return self._partitioner
@@ -489,13 +500,30 @@ class Simulation:
     def compute_thermo(self, every=0):
         self._compute_thermo = every
 
+    def create_update_cells_block(self):
+        """Fill `update_cells_procedures` with the cell-list and neighbor-list build routines."""
+        # Cell lists are always built first and then partitioned
+        steps = [BuildCellLists(self, self.cell_lists)]
+        steps.append(PartitionCellLists(self, self.cell_lists))
+
+        # Optional: neighbor lists stored per cell
+        if self._store_neighbors_per_cell:
+            steps.append(BuildCellNeighborLists(self, self.cell_lists))
+
+        # Optional: neighbor lists stored per particle (standard Verlet lists)
+        if self.neighbor_lists is not None:
+            steps.append(BuildNeighborLists(self, self.neighbor_lists))
+
+        self.update_cells_procedures.add_statement(steps)
+
     def generate(self):
         """Generate the code for the simulation"""
         assert self._target is not None, "Target not specified!"
 
         # Initialize communication instance with the specified domain-partitioner
         self._comm = Comm(self, self._dom_part)
-        
+        self.create_update_cells_block()
+
         if self._generate_whole_program:
             self.generate_program()
         else:
@@ -531,30 +559,25 @@ class Simulation:
         # First steps executed during each time-step in the simulation
         timestep_procedures += self.pre_step_functions 
 
-        comm_routine = [
-            (Exchange(self._comm), every_reneighbor_params),
-            (Borders(self._comm), Synchronize(self._comm), every_reneighbor_params)
-            ]
-        
-        if self._generate_whole_program:
-            timestep_procedures += comm_routine
+        # Rebalancing routines
+        if self.rebalance_frequency:
+            update_domain_procedures = Block.from_list(self, [
+                Exchange(self._comm),
+                UpdateDomain(self),
+                Borders(self._comm),
+                ResetVolatileProperties(self),
+                BuildCellListsStencil(self, self.cell_lists),
+                self.update_cells_procedures
+                ])
 
-        update_cells =    [
-            (BuildCellLists(self, self.cell_lists), every_reneighbor_params),
-            (PartitionCellLists(self, self.cell_lists), every_reneighbor_params)
-        ]
+            timestep_procedures.append((update_domain_procedures, {'every': self.rebalance_frequency}))
 
-        # Add routine to build neighbor-lists per cell
-        if self._store_neighbors_per_cell:
-            update_cells.append(
-                (BuildCellNeighborLists(self, self.cell_lists), every_reneighbor_params))
+        # Communication routines
+        timestep_procedures += [(Exchange(self._comm), every_reneighbor_params),
+                                (Borders(self._comm), Synchronize(self._comm), every_reneighbor_params)]
 
-        # Add routine to build neighbor-lists per particle (standard Verlet Lists)
-        if self.neighbor_lists is not None:
-            update_cells.append(
-                (BuildNeighborLists(self, self.neighbor_lists), every_reneighbor_params))
-
-        timestep_procedures += update_cells
+        # Update acceleration data structures
+        timestep_procedures += [(self.update_cells_procedures, every_reneighbor_params)]
 
         # Add routines for contact history management
         if self._use_contact_history:
@@ -566,15 +589,13 @@ class Simulation:
             timestep_procedures.append(ResetContactHistoryUsageStatus(self, self._contact_history))
 
         # Reset volatile properties
-        if self._generate_whole_program:
-            timestep_procedures += [ResetVolatileProperties(self)]
+        timestep_procedures += [ResetVolatileProperties(self)]
 
-        # add computational kernels
+        # Add computational kernels
         timestep_procedures += self.functions
 
         # For whole-program-generation, add reverse_comm wherever needed in the timestep loop (eg: after computational kernels) like this:
-        if self._generate_whole_program:
-            timestep_procedures += [reverse_comm_module]
+        timestep_procedures += [reverse_comm_module]
 
         # Clear unused contact history
         if self._use_contact_history:
-- 
GitLab