diff --git a/examples/modular/sphere_box_global.cpp b/examples/modular/sphere_box_global.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..783bbaa691b4a9ba29804b421a77daf6b18099b7
--- /dev/null
+++ b/examples/modular/sphere_box_global.cpp
@@ -0,0 +1,101 @@
+#include <iostream>
+#include <memory>
+#include <iomanip>
+
+#include "sphere_box_global.hpp"
+
+// cmake -DINPUT_SCRIPT=../examples/modular/sphere_box_global.py -DWALBERLA_DIR=../../walberla -DBUILD_APP=ON -DUSER_SOURCE_FILES=../examples/modular/sphere_box_global.cpp -DCOMPILE_CUDA=ON ..
+
+void set_feature_properties(std::shared_ptr<PairsAccessor> &ac){
+    ac->setTypeStiffness(0,0, 1e6);
+    ac->setTypeStiffness(0,1, 1e6);
+    ac->setTypeStiffness(1,0, 1e6);
+    ac->setTypeStiffness(1,1, 1e6);
+    ac->syncTypeStiffness();
+
+    ac->setTypeDampingNorm(0,0, 300);
+    ac->setTypeDampingNorm(0,1, 300);
+    ac->setTypeDampingNorm(1,0, 300);
+    ac->setTypeDampingNorm(1,1, 300);
+    ac->syncTypeDampingNorm();
+
+    ac->setTypeFriction(0,0, 1.2);
+    ac->setTypeFriction(0,1, 1.2);
+    ac->setTypeFriction(1,0, 1.2);
+    ac->setTypeFriction(1,1, 1.2);
+    ac->syncTypeFriction();
+
+    ac->setTypeDampingTan(0,0, 300);
+    ac->setTypeDampingTan(0,1, 300);
+    ac->setTypeDampingTan(1,0, 300);
+    ac->setTypeDampingTan(1,1, 300);
+    ac->syncTypeDampingTan();
+}
+
+int main(int argc, char **argv) {
+    auto pairs_sim = std::make_shared<PairsSimulation>();
+    pairs_sim->initialize();
+
+    auto ac = std::make_shared<PairsAccessor>(pairs_sim.get());
+    set_feature_properties(ac);
+
+    auto pairs_runtime = pairs_sim->getPairsRuntime();
+
+    pairs_runtime->initDomain(&argc, &argv, 0, 0, 0, 30, 30, 30, false, false, false, true); 
+    pairs_runtime->getDomainPartitioner()->initWorkloadBalancer(pairs::Hilbert, 100, 800);
+
+    pairs::create_halfspace(pairs_runtime, 0,0,0,       1, 0, 0,    0, pairs::flags::INFINITE | pairs::flags::FIXED);
+    pairs::create_halfspace(pairs_runtime, 0,0,0,       0, 1, 0,    0, pairs::flags::INFINITE | pairs::flags::FIXED);
+    pairs::create_halfspace(pairs_runtime, 0,0,0,       0, 0, 1,    0, pairs::flags::INFINITE | pairs::flags::FIXED);
+    pairs::create_halfspace(pairs_runtime, 30,30,30,    -1, 0, 0,   0, pairs::flags::INFINITE | pairs::flags::FIXED);
+    pairs::create_halfspace(pairs_runtime, 30,30,30,    0, -1, 0,   0, pairs::flags::INFINITE | pairs::flags::FIXED);
+    pairs::create_halfspace(pairs_runtime, 30,30,30,    0, 0, -1,   0, pairs::flags::INFINITE | pairs::flags::FIXED); 
+
+    double radius = 0.5;
+    // Create a bed of small particles
+    pairs::dem_sc_grid(pairs_runtime, 30, 20, 5,  radius*2 , radius*2 , radius*2, radius*2,    2,      250,    2);
+
+    // Create 3 global bodies, one of which is fixed
+    pairs::create_box(pairs_runtime,    12, 12, 13.5,   0, 0, 0,    15, 2, 13,  20,    0,       pairs::flags::GLOBAL); 
+    pairs::create_sphere(pairs_runtime, 15, 20, 15,     0, 4, 0,                50, 4, 0,       pairs::flags::GLOBAL);
+    pairs::create_sphere(pairs_runtime, 15, 25, 4,      0, 0, 0,                50, 4, 0,       pairs::flags::GLOBAL | pairs::flags::FIXED); 
+    
+    // Use the diameter of small particles to set up the cell list
+    double lcw = radius * 2;
+    pairs_sim->setup_sim(lcw, lcw, lcw, lcw);
+    pairs_sim->update_mass_and_inertia();
+    pairs_sim->communicate(0);
+
+    int num_timesteps = 20000; 
+    int vtk_freq = 100;
+    int rebalance_freq = 2000;
+    double dt = 0.001;
+
+    pairs::vtk_write_subdom(pairs_runtime, "output/subdom_init", 0);
+    
+    for (int t=0; t<num_timesteps; ++t){
+        if ((t % vtk_freq==0) && pairs_sim->rank()==0) std::cout << "Timestep: " << t << std::endl;
+        
+        if (t % rebalance_freq == 0){ 
+            pairs_sim->update_domain();
+        }
+        
+        pairs_sim->update_cells(t); 
+
+        pairs_sim->gravity(); 
+        
+        // All global and local interactions are contained within the 'spring_dashpot' module
+        // You have the option to call spring_dashpot before or after 'gravity' or any other force-update module
+        pairs_sim->spring_dashpot();     
+
+        pairs_sim->euler(dt); 
+        pairs_sim->communicate(t);
+        
+        if (t % vtk_freq==0){
+            pairs::vtk_with_rotation(pairs_runtime, pairs::Shapes::Box, "output/local_boxes", 0, pairs_sim->nlocal(), t);
+            pairs::vtk_with_rotation(pairs_runtime, pairs::Shapes::Sphere, "output/local_spheres", 0, pairs_sim->nlocal(), t);
+        }
+    }
+
+    pairs_sim->end();
+}
\ No newline at end of file
diff --git a/examples/modular/sphere_box_global.py b/examples/modular/sphere_box_global.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fbbd6d38d1a1d4b8b852a654bddc651f23efd85
--- /dev/null
+++ b/examples/modular/sphere_box_global.py
@@ -0,0 +1,121 @@
+import math
+import pairs
+import sys
+import os
+        
+def update_mass_and_inertia(i):
+    rotation_matrix[i] = diagonal_matrix(1.0)
+    rotation[i] = default_quaternion()
+
+    if is_sphere(i):
+        inv_inertia[i] = inversed(diagonal_matrix(0.4 * mass[i] * radius[i] * radius[i]))
+
+    elif is_box(i):
+        inv_inertia[i] = inversed(diagonal_matrix (
+            edge_length[i][1]*edge_length[i][1] + edge_length[i][2]*edge_length[i][2],
+            edge_length[i][0]*edge_length[i][0] + edge_length[i][2]*edge_length[i][2],
+            edge_length[i][0]*edge_length[i][0] + edge_length[i][1]*edge_length[i][1]) * (mass[i] / 12.0))
+        
+        axis = vector(1,0.5,1)
+        angle = -3.1415/6.0
+        rotation[i] = quaternion(axis, angle) * rotation[i]
+        rotation_matrix[i] = quaternion_to_rotation_matrix(rotation[i])
+
+    elif is_halfspace(i):
+        mass[i] = infinity
+        inv_inertia[i] = 0.0
+
+
+def spring_dashpot(i, j):
+    delta_ij = -penetration_depth(i, j)
+    skip_when(delta_ij < 0.0)
+    
+    velocity_wf_i = linear_velocity[i] + cross(angular_velocity[i], contact_point(i, j) - position[i])
+    velocity_wf_j = linear_velocity[j] + cross(angular_velocity[j], contact_point(i, j) - position[j])
+    
+    rel_vel = -(velocity_wf_i - velocity_wf_j)
+    rel_vel_n = dot(rel_vel, contact_normal(i, j))
+    rel_vel_t = rel_vel - rel_vel_n * contact_normal(i, j)
+
+    fNabs = stiffness[i,j] * delta_ij + max(damping_norm[i,j] * rel_vel_n, 0.0)
+    fN = fNabs * contact_normal(i, j)
+
+    fTabs = min(damping_tan[i,j] * length(rel_vel_t), friction[i, j] * fNabs)
+    fT =  fTabs * normalized(rel_vel_t)
+
+    partial_force = fN + fT
+    apply(force, partial_force)
+    apply(torque, cross(contact_point(i, j) - position[i], partial_force))
+
+def euler(i):
+    skip_when(is_fixed(i) or is_infinite(i))
+    inv_mass = 1.0 / mass[i]
+    position[i] +=  0.5 * inv_mass * force[i] * dt * dt + linear_velocity[i] * dt
+    linear_velocity[i] += inv_mass * force[i] * dt
+    wdot = rotation_matrix[i] * (inv_inertia[i] * torque[i]) * transposed(rotation_matrix[i])
+    phi = angular_velocity[i] * dt + 0.5 * wdot * dt * dt
+    rotation[i] = quaternion(phi, length(phi)) * rotation[i]
+    rotation_matrix[i] = quaternion_to_rotation_matrix(rotation[i])
+    angular_velocity[i] += wdot * dt
+
+def gravity(i):
+    force[i][2] -= mass[i] * gravity_SI
+
+
+file_name = os.path.basename(__file__)
+file_name_without_extension = os.path.splitext(file_name)[0]
+
+psim = pairs.simulation(
+    file_name_without_extension,
+    [pairs.sphere(), pairs.halfspace(), pairs.box()],
+    double_prec=True,
+    particle_capacity=1000000,
+    neighbor_capacity=20,
+    debug=True, 
+    generate_whole_program=False)
+
+
+target = sys.argv[1] if len(sys.argv[1]) > 1 else "none"
+
+if target == 'gpu':
+    psim.target(pairs.target_gpu())
+elif target == 'cpu':
+    psim.target(pairs.target_cpu())
+else:
+    print(f"Invalid target, use {sys.argv[0]} <cpu/gpu>")
+
+psim.add_position('position')
+psim.add_property('mass', pairs.real())
+psim.add_property('linear_velocity', pairs.vector())
+psim.add_property('angular_velocity', pairs.vector())
+psim.add_property('force', pairs.vector(), volatile=True)
+psim.add_property('torque', pairs.vector(), volatile=True)
+psim.add_property('radius', pairs.real())
+psim.add_property('normal', pairs.vector())
+psim.add_property('inv_inertia', pairs.matrix())
+psim.add_property('rotation_matrix', pairs.matrix())
+psim.add_property('rotation', pairs.quaternion())
+psim.add_property('edge_length', pairs.vector())
+
+ntypes = 2
+psim.add_feature('type', ntypes)
+psim.add_feature_property('type', 'stiffness', pairs.real(), [3000 for i in range(ntypes * ntypes)])
+psim.add_feature_property('type', 'damping_norm', pairs.real(), [10.0 for i in range(ntypes * ntypes)])
+psim.add_feature_property('type', 'damping_tan', pairs.real())
+psim.add_feature_property('type', 'friction', pairs.real())
+
+psim.set_domain_partitioner(pairs.block_forest())
+psim.pbc([False, False, False])
+psim.build_cell_lists()
+
+psim.compute(update_mass_and_inertia, symbols={'infinity': math.inf })
+
+# 'compute_globals' enables computation of forces on global bodies
+psim.compute(spring_dashpot, compute_globals=True)
+psim.compute(euler, parameters={'dt': pairs.real()})
+
+gravity_SI = 9.81
+psim.compute(gravity, symbols={'gravity_SI': gravity_SI })
+
+psim.generate()
+
diff --git a/examples/modular/spring_dashpot.py b/examples/modular/spring_dashpot.py
index 191c000ca61962af8ebae77ab4ec1b97b433d399..d88c8a78cfbfac89b800c3a42715105271f600ae 100644
--- a/examples/modular/spring_dashpot.py
+++ b/examples/modular/spring_dashpot.py
@@ -36,6 +36,7 @@ def spring_dashpot(i, j):
     apply(torque, cross(contact_point(i, j) - position, partial_force))
 
 def euler(i):
+    skip_when(is_fixed(i) or is_infinite(i))
     inv_mass = 1.0 / mass[i]
     position[i] +=  0.5 * inv_mass * force[i] * dt * dt + linear_velocity[i] * dt
     linear_velocity[i] += inv_mass * force[i] * dt
diff --git a/runtime/create_body.cpp b/runtime/create_body.cpp
index 6c431f646db037670fa29b0d4e73ab3e35a420ae..5e23c89b43b12e29a2815de740d8efc7963b8c1e 100644
--- a/runtime/create_body.cpp
+++ b/runtime/create_body.cpp
@@ -18,7 +18,7 @@ id_t create_halfspace(PairsRuntime *pr,
 
     if(pr->getDomainPartitioner()->isWithinSubdomain(x, y, z) || flag & (flags::INFINITE | flags::GLOBAL) ){
         int n = pr->getTrackedVariableAsInteger("nlocal");
-        uid = (flag & flags::GLOBAL) ? UniqueID::createGlobal(pr) : UniqueID::create(pr);
+        uid = (flag & (flags::INFINITE | flags::GLOBAL)) ? UniqueID::createGlobal(pr) : UniqueID::create(pr);
         uids(n) = uid;
         positions(n, 0) = x;
         positions(n, 1) = y;
@@ -28,7 +28,7 @@ id_t create_halfspace(PairsRuntime *pr,
         normals(n, 2) = nz;
         types(n) = type;
         flags(n) = flag;
-        shapes(n) = 1;   // halfspace
+        shapes(n) = Shapes::Halfspace;
         pr->setTrackedVariableAsInteger("nlocal", n + 1);
     }
 
@@ -50,10 +50,11 @@ id_t create_sphere(PairsRuntime *pr,
     auto radii = pr->getAsFloatProperty(pr->getPropertyByName("radius"));
     auto positions = pr->getAsVectorProperty(pr->getPropertyByName("position"));
     auto velocities = pr->getAsVectorProperty(pr->getPropertyByName("linear_velocity"));
+    auto angular_velocity = pr->getAsVectorProperty(pr->getPropertyByName("angular_velocity"));
 
-    if(pr->getDomainPartitioner()->isWithinSubdomain(x, y, z)) {
+    if(pr->getDomainPartitioner()->isWithinSubdomain(x, y, z) || flag & (flags::INFINITE | flags::GLOBAL)) {
         int n = pr->getTrackedVariableAsInteger("nlocal");
-        uid = (flag & flags::GLOBAL) ? UniqueID::createGlobal(pr) : UniqueID::create(pr);
+        uid = (flag & (flags::INFINITE | flags::GLOBAL)) ? UniqueID::createGlobal(pr) : UniqueID::create(pr);
         uids(n) = uid;
         radii(n) = radius;
         masses(n) = ((4.0 / 3.0) * M_PI) * radius * radius * radius * density;
@@ -65,11 +66,57 @@ id_t create_sphere(PairsRuntime *pr,
         velocities(n, 2) = vz;
         types(n) = type;
         flags(n) = flag;
-        shapes(n) = 0;   // sphere
+        shapes(n) = Shapes::Sphere;
+        angular_velocity(n, 0) = 0;
+        angular_velocity(n, 1) = 0;
+        angular_velocity(n, 2) = 0;
         pr->setTrackedVariableAsInteger("nlocal", n + 1);
     }
     
     return uid;
 }
 
+id_t create_box(PairsRuntime *pr, 
+    double x, double y, double z, 
+    double vx, double vy, double vz, 
+    double ex, double ey, double ez, 
+    double density, int type, int flag){
+    // TODO: increase capacity if exceeded
+    id_t uid = 0;
+    auto uids = pr->getAsUInt64Property(pr->getPropertyByName("uid"));   
+    auto shapes = pr->getAsIntegerProperty(pr->getPropertyByName("shape"));
+    auto types = pr->getAsIntegerProperty(pr->getPropertyByName("type"));
+    auto flags = pr->getAsIntegerProperty(pr->getPropertyByName("flags"));
+    auto masses = pr->getAsFloatProperty(pr->getPropertyByName("mass"));
+    auto positions = pr->getAsVectorProperty(pr->getPropertyByName("position"));
+    auto velocities = pr->getAsVectorProperty(pr->getPropertyByName("linear_velocity"));
+    auto edge_length = pr->getAsVectorProperty(pr->getPropertyByName("edge_length"));
+    auto angular_velocity = pr->getAsVectorProperty(pr->getPropertyByName("angular_velocity"));
+
+    if(pr->getDomainPartitioner()->isWithinSubdomain(x, y, z) || flag & (flags::INFINITE | flags::GLOBAL)) {
+        int n = pr->getTrackedVariableAsInteger("nlocal");
+        uid = (flag & (flags::INFINITE | flags::GLOBAL)) ? UniqueID::createGlobal(pr) : UniqueID::create(pr);
+        uids(n) = uid;
+        edge_length(n, 0) = ex;
+        edge_length(n, 1) = ey;
+        edge_length(n, 2) = ez;
+        masses(n) = ex * ey * ez * density;
+        positions(n, 0) = x;
+        positions(n, 1) = y;
+        positions(n, 2) = z;
+        velocities(n, 0) = vx;
+        velocities(n, 1) = vy;
+        velocities(n, 2) = vz;
+        types(n) = type;
+        flags(n) = flag;
+        shapes(n) = Shapes::Box;
+        angular_velocity(n, 0) = 0;
+        angular_velocity(n, 1) = 0;
+        angular_velocity(n, 2) = 0;
+        pr->setTrackedVariableAsInteger("nlocal", n + 1);
+    }
+
+    return uid;
+}
+
 }
\ No newline at end of file
diff --git a/runtime/create_body.hpp b/runtime/create_body.hpp
index 995b1f6998940c09d484fad159ba0a382640a82b..299416408044955277cea10e8dbe0f31cf433eea 100644
--- a/runtime/create_body.hpp
+++ b/runtime/create_body.hpp
@@ -15,4 +15,10 @@ id_t create_sphere(PairsRuntime *pr,
                     double vx, double vy, double vz, 
                     double density, double radius, int type, int flag);
 
+id_t create_box(PairsRuntime *pr, 
+                    double x, double y, double z, 
+                    double vx, double vy, double vz, 
+                    double ex, double ey, double ez, 
+                    double density, int type, int flag);
+
 }
\ No newline at end of file
diff --git a/runtime/dem_sc_grid.cpp b/runtime/dem_sc_grid.cpp
index 119ec78d75cf50dbbc73a9033750bb791353e6b5..30b2cb017d7a93fae2ff8df56758677c220393af 100644
--- a/runtime/dem_sc_grid.cpp
+++ b/runtime/dem_sc_grid.cpp
@@ -71,7 +71,7 @@ int dem_sc_grid(PairsRuntime *ps, double xmax, double ymax, double zmax, double
             velocities(nparticles, 2) = 0.1 * realRandom<real_t>(-initial_velocity, initial_velocity);
             types(nparticles) = rand() % ntypes;
             flags(nparticles) = 0;
-            shapes(nparticles) = shapes::Sphere;
+            shapes(nparticles) = Shapes::Sphere;
 
             /*
             std::cout << uid(nparticles) << "," << types(nparticles) << "," << masses(nparticles) << "," << radius(nparticles) << ","
diff --git a/runtime/pairs.cpp b/runtime/pairs.cpp
index 6efead8d8cb598c1fcd481beb35c12464c235830..d1c2d635d230c727853159c5f1ac4abec446591f 100644
--- a/runtime/pairs.cpp
+++ b/runtime/pairs.cpp
@@ -188,9 +188,9 @@ void PairsRuntime::copyArraySliceToDevice(
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !array_flags->isDeviceFlagSet(array_id)) {
             if(!array.isStatic()) {
-                PAIRS_DEBUG(
-                    "Copying array %s to device (offset=%lu, n=%lu)\n",
-                    array.getName().c_str(), offset, size);
+                // PAIRS_DEBUG(
+                //     "Copying array %s to device (offset=%lu, n=%lu)\n",
+                //     array.getName().c_str(), offset, size);
 
                 pairs::copy_slice_to_device(
                     array.getHostPointer(), array.getDevicePointer(), offset, size);
@@ -211,16 +211,16 @@ void PairsRuntime::copyArrayToDevice(Array &array, action_t action, size_t size)
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !array_flags->isDeviceFlagSet(array_id)) {
             if(array.isStatic()) {
-                PAIRS_DEBUG(
-                    "Copying static array %s to device (n=%lu)\n",
-                    array.getName().c_str(), size);
+                // PAIRS_DEBUG(
+                //     "Copying static array %s to device (n=%lu)\n",
+                //     array.getName().c_str(), size);
 
                 pairs::copy_static_symbol_to_device(
                     array.getHostPointer(), array.getDevicePointer(), size);
             } else {
-                PAIRS_DEBUG(
-                    "Copying array %s to device (n=%lu)\n",
-                    array.getName().c_str(), size);
+                // PAIRS_DEBUG(
+                //     "Copying array %s to device (n=%lu)\n",
+                //     array.getName().c_str(), size);
 
                 pairs::copy_to_device(array.getHostPointer(), array.getDevicePointer(), size);
             }
@@ -240,9 +240,9 @@ void PairsRuntime::copyArraySliceToHost(Array &array, action_t action, size_t of
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !array_flags->isHostFlagSet(array_id)) {
             if(!array.isStatic()) {
-                PAIRS_DEBUG(
-                    "Copying array %s to host (offset=%lu, n=%lu)\n",
-                    array.getName().c_str(), offset, size);
+                // PAIRS_DEBUG(
+                //     "Copying array %s to host (offset=%lu, n=%lu)\n",
+                //     array.getName().c_str(), offset, size);
 
                 pairs::copy_slice_to_host(
                     array.getDevicePointer(), array.getHostPointer(), offset, size);
@@ -263,13 +263,13 @@ void PairsRuntime::copyArrayToHost(Array &array, action_t action, size_t size) {
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !array_flags->isHostFlagSet(array_id)) {
             if(array.isStatic()) {
-                PAIRS_DEBUG(
-                    "Copying static array %s to host (n=%lu)\n", array.getName().c_str(), size);
+                // PAIRS_DEBUG(
+                //     "Copying static array %s to host (n=%lu)\n", array.getName().c_str(), size);
 
                 pairs::copy_static_symbol_to_host(
                     array.getDevicePointer(), array.getHostPointer(), size);
             } else {
-                PAIRS_DEBUG("Copying array %s to host (n=%lu)\n", array.getName().c_str(), size);
+                // PAIRS_DEBUG("Copying array %s to host (n=%lu)\n", array.getName().c_str(), size);
                 pairs::copy_to_host(array.getDevicePointer(), array.getHostPointer(), size);
             }
         }
@@ -287,7 +287,7 @@ void PairsRuntime::copyPropertyToDevice(Property &prop, action_t action, size_t
 
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !prop_flags->isDeviceFlagSet(prop_id)) {
-            PAIRS_DEBUG("Copying property %s to device (n=%lu)\n", prop.getName().c_str(), size);
+            // PAIRS_DEBUG("Copying property %s to device (n=%lu)\n", prop.getName().c_str(), size);
             pairs::copy_to_device(prop.getHostPointer(), prop.getDevicePointer(), size);
         }
     }
@@ -304,7 +304,7 @@ void PairsRuntime::copyPropertyToHost(Property &prop, action_t action, size_t si
 
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !prop_flags->isHostFlagSet(prop_id)) {
-            PAIRS_DEBUG("Copying property %s to host (n=%lu)\n", prop.getName().c_str(), size);
+            // PAIRS_DEBUG("Copying property %s to host (n=%lu)\n", prop.getName().c_str(), size);
             pairs::copy_to_host(prop.getDevicePointer(), prop.getHostPointer(), size);
         }
     }
@@ -323,8 +323,8 @@ void PairsRuntime::copyContactPropertyToDevice(
 
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !contact_prop_flags->isDeviceFlagSet(prop_id)) {
-            PAIRS_DEBUG("Copying contact property %s to device (n=%lu)\n",
-                contact_prop.getName().c_str(), size);
+            // PAIRS_DEBUG("Copying contact property %s to device (n=%lu)\n",
+            //     contact_prop.getName().c_str(), size);
 
             pairs::copy_to_device(
                 contact_prop.getHostPointer(), contact_prop.getDevicePointer(), size);
@@ -345,8 +345,8 @@ void PairsRuntime::copyContactPropertyToHost(
 
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(!contact_prop_flags->isHostFlagSet(contact_prop.getId())) {
-            PAIRS_DEBUG("Copying contact property %s to host (n=%lu)\n",
-                contact_prop.getName().c_str(), size);
+            // PAIRS_DEBUG("Copying contact property %s to host (n=%lu)\n",
+            //     contact_prop.getName().c_str(), size);
 
             pairs::copy_to_host(
                 contact_prop.getDevicePointer(), contact_prop.getHostPointer(), size);
@@ -363,8 +363,8 @@ void PairsRuntime::copyContactPropertyToHost(
 void PairsRuntime::copyFeaturePropertyToDevice(FeatureProperty &feature_prop) {
     const size_t n = feature_prop.getArraySize();
 
-    PAIRS_DEBUG("Copying feature property %s to device (n=%lu)\n",
-        feature_prop.getName().c_str(), n);
+    // PAIRS_DEBUG("Copying feature property %s to device (n=%lu)\n",
+    //     feature_prop.getName().c_str(), n);
 
     pairs::copy_static_symbol_to_device(
         feature_prop.getHostPointer(), feature_prop.getDevicePointer(), n);
@@ -644,4 +644,22 @@ void PairsRuntime::copyRuntimeArray(const std::string& name, void *dest, const i
     this->getDomainPartitioner()->copyRuntimeArray(name, dest, size);
 }
 
+void PairsRuntime::allReduceInplaceSum(real_t *red_buffer, int num_elems){
+    real_t *buff_ptr = red_buffer;
+    auto buff_array = getArrayByHostPointer(red_buffer);
+
+    #ifdef ENABLE_CUDA_AWARE_MPI
+    buff_ptr = (real_t *) buff_array.getDevicePointer();
+    #else
+    copyArrayToHost(buff_array, Ignore, num_elems * sizeof(real_t));
+    #endif
+
+    MPI_Allreduce(MPI_IN_PLACE, buff_ptr, num_elems, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
+
+    #ifndef ENABLE_CUDA_AWARE_MPI
+    copyArrayToDevice(buff_array, Ignore, num_elems * sizeof(real_t));
+    #endif
+}
+
+
 }
diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp
index e87dec06224d830f8f15fdc2e593278c00d100e6..14ffc5a283394cf23806b6c59916778da2b9be07 100644
--- a/runtime/pairs.hpp
+++ b/runtime/pairs.hpp
@@ -355,6 +355,8 @@ public:
     int getNumberOfNeighborRanks() { return this->getDomainPartitioner()->getNumberOfNeighborRanks(); }
     int getNumberOfNeighborAABBs() { return this->getDomainPartitioner()->getNumberOfNeighborAABBs(); }
 
+    void allReduceInplaceSum(real_t *red_buffer, int num_elems);
+
     // Device functions
     void sync() { device_synchronize(); }
 
diff --git a/runtime/pairs_common.hpp b/runtime/pairs_common.hpp
index 74237423ee5d3de07462bcc40739484ef4fd9781..03766e1459d3a9a3da8d12c38335a6a52101fb1b 100644
--- a/runtime/pairs_common.hpp
+++ b/runtime/pairs_common.hpp
@@ -22,13 +22,13 @@ namespace flags{
     constexpr int GLOBAL   = 1 << 3 ;
 }
 
-namespace shapes{
-    enum Type {
-        Sphere = 0,
-        Halfspace = 1,
-        PointMass = 2
-    };
-}
+enum Shapes {
+    Sphere = 0,
+    Halfspace = 1,
+    PointMass = 2,
+    Box = 3
+};
+
 //#ifdef USE_DOUBLE_PRECISION
 typedef double real_t;
 //#else
@@ -127,6 +127,7 @@ constexpr const char* getAlgorithmName(LoadBalancingAlgorithms alg) {
 #   define PAIRS_ASSERT(a)      assert(a)
 #   define PAIRS_EXCEPTION(a)
 #else
+// #   define PAIRS_DEBUG(...) {printf(__VA_ARGS__);}
 #   define PAIRS_DEBUG(...)
 #   define PAIRS_ASSERT(a)
 #   define PAIRS_EXCEPTION(a)
@@ -135,5 +136,6 @@ constexpr const char* getAlgorithmName(LoadBalancingAlgorithms alg) {
 #define PAIRS_ERROR(...)        fprintf(stderr, __VA_ARGS__)
 #define MIN(a,b)                ((a) < (b) ? (a) : (b))
 #define MAX(a,b)                ((a) > (b) ? (a) : (b))
+#define SIGN(a)                 ((a) < 0 ? -1 : 1)
 
 }
\ No newline at end of file
diff --git a/runtime/vtk.cpp b/runtime/vtk.cpp
index b6235725c9d7d10cb38e4033a1e456b6b50ab32a..57de0eb5489594c6099969d881187c98989f3bb4 100644
--- a/runtime/vtk.cpp
+++ b/runtime/vtk.cpp
@@ -6,6 +6,107 @@
 
 namespace pairs {
 
+
+void vtk_with_rotation(
+    PairsRuntime *ps, Shapes shape, const char *filename, int start, int end, int timestep, int frequency) {
+
+    std::string output_filename(filename);
+    auto masses = ps->getAsFloatProperty(ps->getPropertyByName("mass"));
+    auto positions = ps->getAsVectorProperty(ps->getPropertyByName("position"));
+    auto radius = ps->getAsFloatProperty(ps->getPropertyByName("radius"));
+    auto rotation_matrix = ps->getAsMatrixProperty(ps->getPropertyByName("rotation_matrix"));
+    auto shapes = ps->getAsIntegerProperty(ps->getPropertyByName("shape"));
+    const int prec = 8;
+    int n = end - start;
+    std::ostringstream filename_oss;
+
+    if(frequency != 0 && timestep % frequency != 0) {
+        return;
+    }
+
+    filename_oss << filename << "_";
+    if(ps->getDomainPartitioner()->getWorldSize() > 1) {
+        filename_oss << "r" << ps->getDomainPartitioner()->getRank() << "_";
+    }
+
+    filename_oss << timestep << ".vtk";
+    std::ofstream out_file(filename_oss.str());
+
+    ps->copyPropertyToHost(masses, ReadOnly);
+    ps->copyPropertyToHost(positions, ReadOnly);
+    ps->copyPropertyToHost(radius, ReadOnly);
+    ps->copyPropertyToHost(rotation_matrix, ReadOnly);
+    ps->copyPropertyToHost(shapes, ReadOnly);
+
+    for(int i = start; i < end; i++) {
+        if(shapes(i) != shape) {
+            n--;
+        }
+    }
+
+    if(out_file.is_open()) {
+        out_file << "# vtk DataFile Version 2.0\n";
+        out_file << "Particle data\n";
+        out_file << "ASCII\n";
+        out_file << "DATASET POLYDATA\n";
+        out_file << "POINTS " << n << " double\n";
+
+        out_file << std::fixed << std::setprecision(prec);
+        for(int i = start; i < end; i++) {
+            if (shapes(i) == shape) {
+                out_file << positions(i, 0) << " ";
+                out_file << positions(i, 1) << " ";
+                out_file << positions(i, 2) << "\n";
+            }
+        }
+
+        out_file << "\n\n";
+        out_file << "POINT_DATA " << n << "\n";
+        out_file << "SCALARS mass double 1\n";
+        out_file << "LOOKUP_TABLE default\n";
+        for(int i = start; i < end; i++) {
+            if (shapes(i) == shape) {
+                out_file << masses(i) << "\n";
+            }
+        }
+
+        out_file << "\n\n";
+        out_file << "SCALARS radius double 1\n";
+        out_file << "LOOKUP_TABLE default\n";
+        for(int i = start; i < end; i++) {
+            if (shapes(i) == shape) {
+                out_file << radius(i) << "\n";
+            }
+        }
+
+        out_file << "\n\n";
+        out_file << "TENSORS rotation float\n";
+        for(int i = start; i < end; i++) {
+            if (shapes(i) == shape) {
+                out_file    << rotation_matrix(i, 0) << " " 
+                            << rotation_matrix(i, 3) << " " 
+                            << rotation_matrix(i, 6) << "\n";
+            
+                out_file    << rotation_matrix(i, 1) << " " 
+                            << rotation_matrix(i, 4) << " " 
+                            << rotation_matrix(i, 7) << "\n";
+            
+                out_file    << rotation_matrix(i, 2) << " " 
+                            << rotation_matrix(i, 5) << " " 
+                            << rotation_matrix(i, 8) << "\n";
+            }
+        }
+
+        out_file << "\n\n";
+        out_file.close();
+    }
+    else {
+        std::cerr << "Failed to open " << filename_oss.str() << std::endl;
+        exit(-1);
+    }
+}
+
+
 void vtk_write_aabb(PairsRuntime *ps, const char *filename, int num,
     double xmin, double xmax, 
     double ymin, double ymax, 
diff --git a/runtime/vtk.hpp b/runtime/vtk.hpp
index dcd97c020f1b49cdf82df548083687e1874f7c51..ac369d5c776f06ae626ab142e16e79dd74b81c7e 100644
--- a/runtime/vtk.hpp
+++ b/runtime/vtk.hpp
@@ -4,6 +4,9 @@
 
 namespace pairs {
 
+void vtk_with_rotation(
+    PairsRuntime *ps, Shapes shape, const char *filename, int start, int end, int timestep, int frequency=1);
+
 void vtk_write_aabb(PairsRuntime *ps, const char *filename, int num,
     double xmin, double xmax, 
     double ymin, double ymax, 
diff --git a/src/pairs/__init__.py b/src/pairs/__init__.py
index 6525a9814b0fc4729c0cf3e24e5a748c5db3ae4b..ec005e5cb90d50b2b50d47da3e6955fd787146d8 100644
--- a/src/pairs/__init__.py
+++ b/src/pairs/__init__.py
@@ -59,6 +59,9 @@ def point_mass():
 def sphere():
     return Shapes.Sphere
 
+def box():
+    return Shapes.Box
+
 def halfspace():
     return Shapes.Halfspace
 
diff --git a/src/pairs/analysis/__init__.py b/src/pairs/analysis/__init__.py
index 7b200b201ef6b1126275c6656c98419b36e2d89a..a7abf13a4babd0e07acb740fb361c9cce868b3eb 100644
--- a/src/pairs/analysis/__init__.py
+++ b/src/pairs/analysis/__init__.py
@@ -1,7 +1,7 @@
 import time
 from pairs.analysis.expressions import DetermineExpressionsTerminals, ResetInPlaceOperations, DetermineInPlaceOperations, ListDeclaredExpressions
 from pairs.analysis.blocks import DiscoverBlockVariants, DetermineExpressionsOwnership, DetermineParentBlocks
-from pairs.analysis.devices import FetchKernelReferences, MarkCandidateLoops
+from pairs.analysis.devices import FetchKernelReferences, MarkCandidateLoops, FetchDeviceCopies
 from pairs.analysis.modules import FetchModulesReferences, InferModulesReturnTypes
 
 
@@ -53,4 +53,7 @@ class Analysis:
         self.apply(MarkCandidateLoops())
 
     def infer_modules_return_types(self):
-        self.apply(InferModulesReturnTypes())
\ No newline at end of file
+        self.apply(InferModulesReturnTypes())
+
+    def fetch_device_copies(self):
+        self.apply(FetchDeviceCopies())
\ No newline at end of file
diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py
index 29e554e4606776693cdfd2dd784fc24d0b6995ea..c93ac99b42f22e04091c4a25212434ca93bb726d 100644
--- a/src/pairs/analysis/devices.py
+++ b/src/pairs/analysis/devices.py
@@ -9,6 +9,32 @@ from pairs.ir.visitor import Visitor
 from pairs.ir.vectors import VectorOp
 
 
+class FetchDeviceCopies(Visitor):
+    def __init__(self, ast=None):
+        super().__init__(ast)
+        self.module_stack = []
+
+    def visit_Module(self, ast_node):      
+        self.module_stack.append(ast_node)
+        self.visit_children(ast_node)
+        self.module_stack.pop()
+
+    def visit_CopyArray(self, ast_node):
+        print(self.module_stack[-1].name , " array = ", ast_node.array().name() )
+        self.module_stack[-1].add_device_copy(ast_node.array())
+
+    def visit_CopyProperty(self, ast_node):
+        self.module_stack[-1].add_device_copy(ast_node.prop())
+
+    def visit_CopyFeatureProperty(self, ast_node):
+        self.module_stack[-1].add_device_copy(ast_node.prop())
+
+    def visit_CopyContactProperty(self, ast_node):
+        self.module_stack[-1].add_device_copy(ast_node.contact_prop())
+
+    def visit_CopyVar(self, ast_node):
+        self.module_stack[-1].add_device_copy(ast_node.variable())
+
 class MarkCandidateLoops(Visitor):
     def __init__(self, ast=None):
         super().__init__(ast)
diff --git a/src/pairs/code_gen/accessor.py b/src/pairs/code_gen/accessor.py
index 34421cb64cde53f7a75a9079b39ee90c25b7b235..6f866bebae1f509539c43b5397ae84ff33fa486f 100644
--- a/src/pairs/code_gen/accessor.py
+++ b/src/pairs/code_gen/accessor.py
@@ -253,7 +253,7 @@ class PairsAcessor:
         
         self.print(f"{self.host_device_attr}{tkw} get{funcname}({params}) const{{")
 
-        if self.target.is_gpu():
+        if self.target.is_gpu() and prop.device_flag:
             self.ifdef_else("__CUDA_ARCH__", self.getter_body, [prop, True], self.getter_body, [prop, False])
         else:
             self.getter_body(prop, False)
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 76e22283330f531eed2f449a30eee798bf62c00e..028c09087aff54ffae963332b6f2e7b4549b20a5 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -45,6 +45,7 @@ class CGen:
         self.target = None
         self.print = None
         self.kernel_context = False
+        self.loop_scope = False
         self.generate_full_object_names = False
         self.ref = ref
         self.debug = debug
@@ -486,8 +487,8 @@ class CGen:
         self.generate_module_header(module, definition=True)
         self.print.add_indent(4)
 
-        if self.debug:
-            self.print(f"PAIRS_DEBUG(\"\\n{module.name}\\n\");")
+        # if self.debug:
+        #     self.print(f"PAIRS_DEBUG(\"\\n{module.name}\\n\");")
 
         if not module.interface:
             self.generate_module_declerations(module)
@@ -617,7 +618,10 @@ class CGen:
             self.print.add_indent(-4)
 
         if isinstance(ast_node, Continue):
-            self.print("continue;")
+            if self.loop_scope:
+                self.print("continue;")
+            else:
+                self.print("return;")
 
         # TODO: Why there are Decls for other types?
         if isinstance(ast_node, Decl):
@@ -755,7 +759,7 @@ class CGen:
                 for i in matrix_op.indexes_to_generate():
                     lhs = self.generate_expression(matrix_op.lhs, matrix_op.mem, index=i)
                     rhs = self.generate_expression(matrix_op.rhs, index=i)
-                    operator = vector_op.operator()
+                    operator = matrix_op.operator()
 
                     if operator.is_unary():
                         self.print(f"const {self.real_type()} {matrix_op.name()}_{dim} = {operator.symbol()}({lhs});")
@@ -848,7 +852,9 @@ class CGen:
                 self.print("#pragma omp parallel for")
 
             self.print(f"for(int {iterator} = {lower_range}; {iterator} < {upper_range}; {iterator}++) {{")
+            self.loop_scope = True
             self.generate_statement(ast_node.block)
+            self.loop_scope = False
             self.print("}")
 
 
@@ -1060,7 +1066,9 @@ class CGen:
         if isinstance(ast_node, While):
             cond = self.generate_expression(ast_node.cond)
             self.print(f"while({cond}) {{")
+            self.loop_scope = True
             self.generate_statement(ast_node.block)
+            self.loop_scope = False
             self.print("}")
 
         if isinstance(ast_node, Return):
diff --git a/src/pairs/code_gen/interface.py b/src/pairs/code_gen/interface.py
index 6ed3f7f2b511137000e87f8db64c4a70894efd35..cf59ca280fe9425b6fa8b364f4e643e60a051140 100644
--- a/src/pairs/code_gen/interface.py
+++ b/src/pairs/code_gen/interface.py
@@ -245,7 +245,7 @@ class InterfaceModules:
     @pairs_interface_block
     def end(self):
         self.sim.module_name('end')
-        Call_Void(self.sim, "pairs::print_timers", [])
+        # Call_Void(self.sim, "pairs::print_timers", [])
         Call_Void(self.sim, "pairs::print_stats", [self.sim.nlocal, self.sim.nghost])
         PrintCode(self.sim, "delete pobj;")
         PrintCode(self.sim, "delete pairs_runtime;")
diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py
index 8842818627c7da514ec9a24ae85c84a5f08cd747..7d90c3c97e3fac7d4cb7cdf48b9d8b1d40720808 100644
--- a/src/pairs/ir/loops.py
+++ b/src/pairs/ir/loops.py
@@ -88,8 +88,8 @@ class For(ASTNode):
 
 
 class ParticleFor(For):
-    def __init__(self, sim, block=None, local_only=True):
-        super().__init__(sim, 0, sim.nlocal if local_only else sim.nlocal + sim.nghost, block)
+    def __init__(self, sim, block=None, local_only=True, not_kernel=False):
+        super().__init__(sim, 0, sim.nlocal if local_only else sim.nlocal + sim.nghost, block, not_kernel)
         self.local_only = local_only
 
     def __str__(self):
diff --git a/src/pairs/ir/math.py b/src/pairs/ir/math.py
index a6a156a4986a6a8b122dc70d8ca920ad18d29269..d9bd5730230d1f209b5c598f8b3ff55254f36496 100644
--- a/src/pairs/ir/math.py
+++ b/src/pairs/ir/math.py
@@ -83,7 +83,48 @@ class Abs(MathFunction):
     def type(self):
         return self._params[0].type()
 
+class Min(MathFunction):
+    def __init__(self, sim, a, b):
+        super().__init__(sim)
+        self._params = [a, b]
+
+    def __str__(self):
+        return f"Min<{self._params}>"
+
+    def function_name(self):
+        return "MIN"
+
+    def type(self):
+        return self._params[0].type()
+    
+class Max(MathFunction):
+    def __init__(self, sim, a, b):
+        super().__init__(sim)
+        self._params = [a, b]
+
+    def __str__(self):
+        return f"Max<{self._params}>"
+
+    def function_name(self):
+        return "MAX"
 
+    def type(self):
+        return self._params[0].type()
+    
+class Sign(MathFunction):
+    def __init__(self, sim, expr):
+        super().__init__(sim)
+        self._params = [expr]
+
+    def __str__(self):
+        return f"Sign<{self._params}>"
+
+    def function_name(self):
+        return "SIGN"
+
+    def type(self):
+        return self._params[0].type()
+    
 class Sin(MathFunction):
     def __init__(self, sim, expr):
         super().__init__(sim)
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
index ded67ac6f4448590c346f51177b17fe364b729d0..3f2df80f47f0d361b5f627b1478d9c87c742ff63 100644
--- a/src/pairs/ir/module.py
+++ b/src/pairs/ir/module.py
@@ -29,6 +29,7 @@ class Module(ASTNode):
         self._contact_properties = {}
         self._feature_properties = {}
         self._host_references = set()
+        self._device_copies = set()
         self._block = block
         self._resizes_to_check = resizes_to_check
         self._check_properties_resize = check_properties_resize
@@ -115,6 +116,9 @@ class Module(ASTNode):
     def host_references(self):
         return self._host_references
 
+    def device_copies(self):
+        return self._device_copies
+    
     def add_array(self, array, write=False):
         array_list = array if isinstance(array, list) else [array]
         new_op = 'w' if write else 'r'
@@ -185,6 +189,9 @@ class Module(ASTNode):
     def add_host_reference(self, elem):
         self._host_references.add(elem)
 
+    def add_device_copy(self, elem):
+        self._device_copies.add(elem)
+
     def children(self):
         return [self._block]
 
diff --git a/src/pairs/ir/symbols.py b/src/pairs/ir/symbols.py
index f7ba470282bc9a8e1f0ef88a4a30437bdc711fbe..0c313c715074954eb8c149ea1df6db7125804575 100644
--- a/src/pairs/ir/symbols.py
+++ b/src/pairs/ir/symbols.py
@@ -7,16 +7,20 @@ from pairs.ir.types import Types
 
 
 class Symbol(ASTTerm):
-    def __init__(self, sim, sym_type):
+    def __init__(self, sim, sym_type, name=None):
         super().__init__(sim, OperatorClass.from_type(sym_type))
         self.sym_type = sym_type
         self.assign_to = None
+        self.name = name
 
     def __str__(self):
         return f"Symbol<{Types.c_keyword(self.sim, self.sym_type)}>"
 
     def assign(self, node):
         self.assign_to = node
+    
+    def get_assigned_node(self):
+        return self.assign_to
 
     def type(self):
         return self.sym_type
diff --git a/src/pairs/ir/variables.py b/src/pairs/ir/variables.py
index 00bcc698214a3938b5ae5c3b9b6917114f596c5c..049496c10c32eac0cec29bdbeab84de169c8b27e 100644
--- a/src/pairs/ir/variables.py
+++ b/src/pairs/ir/variables.py
@@ -4,6 +4,8 @@ from pairs.ir.assign import Assign
 from pairs.ir.scalars import ScalarOp
 from pairs.ir.lit import Lit
 from pairs.ir.operator_class import OperatorClass
+from pairs.ir.types import Types
+from pairs.ir.accessor_class import AccessorClass
 
 
 class Variables:
@@ -23,10 +25,10 @@ class Variables:
         self.vars.append(var)
         return var
 
-    def add_temp(self, init):
+    def add_temp(self, init, type):
         lit = Lit.cvt(self.sim, init)
         tmp_id = Variables.new_temp_id()
-        tmp_var = Var(self.sim, f"tmp{tmp_id}", lit.type(), temp=True)
+        tmp_var = Var(self.sim, f"tmp{tmp_id}", lit.type() if type is None else type, temp=True)
         Assign(self.sim, tmp_var, lit)
         return tmp_var
 
@@ -57,6 +59,11 @@ class Var(ASTTerm):
     def __str__(self):
         return f"Var<{self.var_name}>"
 
+    def __getitem__(self, index):
+        assert not Types.is_scalar(self.var_type)
+        _acc_class = AccessorClass.from_type(self.var_type)
+        return _acc_class(self.sim, self, Lit.cvt(self.sim, index))
+    
     def copy(self, deep=False):
         # Terminal copies are just themselves
         return self
diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py
index b4ec2f8f5fcb083a92f3eb5a43f6f33cc49093a5..2523ba4129e5bfaff41c06c174f5188e33b33b42 100644
--- a/src/pairs/mapping/funcs.py
+++ b/src/pairs/mapping/funcs.py
@@ -13,7 +13,10 @@ from pairs.ir.types import Types
 from pairs.mapping.keywords import Keywords
 from pairs.sim.flags import Flags
 from pairs.sim.interaction import ParticleInteraction
-
+from pairs.sim.global_interaction import  GlobalLocalInteraction, GlobalGlobalInteraction, GlobalReduction, SortGlobals, PackGlobals, ResetReductionProps, ReduceGlobals, UnpackGlobals
+from pairs.ir.module import Module
+from pairs.ir.block import Block, pairs_device_block
+from pairs.sim.lowerable import Lowerable
 
 class UndefinedSymbol():
     def __init__(self, symbol_id):
@@ -284,11 +287,30 @@ class BuildParticleIR(ast.NodeVisitor):
         op_class = OperatorClass.from_type(operand.type())
         return op_class(self.sim, operand, None, BuildParticleIR.get_unary_op(node.op))
 
+class OneBodyKernel(Lowerable):
+    def __init__(self, sim, module_name):
+        super().__init__(sim)
+        self.block = Block(sim, [])
+        self.module_name = module_name
+
+    def add_statement(self, stmt):
+        self.block.add_statement(stmt)
+
+    def __iter__(self):
+        self.sim.add_statement(self)
+        self.sim.enter(self)
+        for i in ParticleFor(self.sim):
+            yield i
+        self.sim.leave()
+    
+    @pairs_device_block
+    def lower(self):
+        self.sim.module_name(self.module_name)
+        self.sim.add_statement(self.block)
 
-def compute(sim, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=False, skip_first=False):
+def compute(sim, func, cutoff_radius=None, symbols={}, parameters={}, compute_globals=False):
     if sim._generate_whole_program:
         assert not parameters, "Compute functions can't take custom parameters when generating whole program."
-    
 
     src = inspect.getsource(func)
     tree = ast.parse(src, mode='exec')
@@ -311,14 +333,15 @@ def compute(sim, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=F
     sim.module_name(func.__name__)
 
     if nparams == 1:
-        for i in ParticleFor(sim):
-            for _ in Filter(sim, ScalarOp.cmp(sim.particle_flags[i] & Flags.Fixed, 0)):
-                ir = BuildParticleIR(sim, symbols, parameters)
-                ir.add_symbols({params[0]: i})
-                ir.visit(tree)
+        for i in OneBodyKernel(sim, func.__name__):
+            ir = BuildParticleIR(sim, symbols, parameters)
+            ir.add_symbols({params[0]: i})
+            ir.visit(tree)
 
     else:
-        for interaction_data in ParticleInteraction(sim, nparams, cutoff_radius):
+        # Compute local-local and local-global interactions
+        particle_interaction = ParticleInteraction(sim, func.__name__, nparams, cutoff_radius)
+        for interaction_data in particle_interaction:
             # Start building IR
             ir = BuildParticleIR(sim, symbols, parameters)
             ir.add_symbols({
@@ -332,43 +355,59 @@ def compute(sim, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=F
                 '__contact_point__': interaction_data.contact_point(),
                 '__contact_normal__': interaction_data.contact_normal()
             })
-
             ir.visit(tree)
 
-    if sim._generate_whole_program:
-        if pre_step:
-            sim.build_pre_step_module_with_statements(skip_first=skip_first, profile=True)
-        else:
-            sim.build_module_with_statements(skip_first=skip_first, profile=True)
-    else:
-        sim.build_user_defined_function()
-
-def setup(sim, func, symbols={}):
-    src = inspect.getsource(func)
-    tree = ast.parse(src, mode='exec')
-
-    # Fetch function info
-    info = FetchParticleFuncInfo()
-    info.visit(tree)
-    params = info.params()
-    nparams = info.nparams()
-
-    # Compute functions must have parameters
-    assert nparams == 1, "Number of parameters from setup functions must be one!"
+        if compute_globals:
+            # If compute_globals is enabled, global interactions and reductions become 
+            # part of the same module as local interactions
+            global_reduction = GlobalReduction(sim, func.__name__, particle_interaction)
+            
+            SortGlobals(global_reduction)           # Sort global bodies with respect to their uid
+            PackGlobals(global_reduction)           # Pack reduction properties in uid order in an intermediate buffer
+            ResetReductionProps(global_reduction)   # Reset reduction properties to be prepared for local reduction 
+            
+            # Compute local contirbutions on global bodies
+            global_local_interactions = GlobalLocalInteraction(sim, func.__name__, nparams)
+            for interaction_data in global_local_interactions:
+                ir = BuildParticleIR(sim, symbols, parameters)
+                ir.add_symbols({
+                    params[0]: interaction_data.i(),
+                    params[1]: interaction_data.j(),
+                    '__i__': interaction_data.i(),
+                    '__j__': interaction_data.j(),
+                    '__delta__': interaction_data.delta(),
+                    '__squared_distance__': interaction_data.squared_distance(),
+                    '__penetration_depth__': interaction_data.penetration_depth(),
+                    '__contact_point__': interaction_data.contact_point(),
+                    '__contact_normal__': interaction_data.contact_normal()
+                })
+                ir.visit(tree)
 
-    # Convert literal symbols
-    symbols = {symbol: Lit.cvt(sim, value) for symbol, value in symbols.items()}
 
-    sim.init_block()
-    sim.module_name(func.__name__)
+            PackGlobals(global_reduction, False)    # Pack local contributions in reduction buffer in uid order
+            ReduceGlobals(global_reduction)         # Inplace reduce local contributions over global bodies in the reduction buffer
+            UnpackGlobals(global_reduction)         # Add the reduced properties with the intermediate buffer and unpack into global bodies 
 
-    for i in ParticleFor(sim):
-        ir = BuildParticleIR(sim, symbols)
-        ir.add_symbols({params[0]: i})
-        ir.visit(tree)
+            # Compute global-global interactions
+            global_global_interactions = GlobalGlobalInteraction(sim, func.__name__, nparams)
+            for interaction_data in global_global_interactions:
+                ir = BuildParticleIR(sim, symbols, parameters)
+                ir.add_symbols({
+                    params[0]: interaction_data.i(),
+                    params[1]: interaction_data.j(),
+                    '__i__': interaction_data.i(),
+                    '__j__': interaction_data.j(),
+                    '__delta__': interaction_data.delta(),
+                    '__squared_distance__': interaction_data.squared_distance(),
+                    '__penetration_depth__': interaction_data.penetration_depth(),
+                    '__contact_point__': interaction_data.contact_point(),
+                    '__contact_normal__': interaction_data.contact_normal()
+                })
+                ir.visit(tree)
 
-    if sim._generate_whole_program:
-        sim.build_setup_module_with_statements()
-    else:
-        sim.build_user_defined_function()
+            
+            
+    # User defined functions are wrapped inside seperate interface modules here.
+    # The udf's have the same name as their interface module but they get implemented in the pairs::internal scope.
+    sim.build_interface_module()  
     
diff --git a/src/pairs/mapping/keywords.py b/src/pairs/mapping/keywords.py
index 51c255de3767f567182afa8dc5616e5f85151604..f3959e0b1c746b9a82f6df072927cc3852c9b13f 100644
--- a/src/pairs/mapping/keywords.py
+++ b/src/pairs/mapping/keywords.py
@@ -13,7 +13,7 @@ from pairs.ir.types import Types
 from pairs.ir.print import Print
 from pairs.ir.vectors import Vector, ZeroVector
 from pairs.sim.shapes import Shapes
-
+from pairs.sim.flags import Flags
 
 class Keywords:
     def __init__(self, sim):
@@ -47,12 +47,42 @@ class Keywords:
         assert particle_id.type() == Types.Int32, "Particle ID must be an integer."
         return ScalarOp.cmp(self.sim.particle_shape[particle_id], Shapes.Sphere)
 
+    def keyword_is_box(self, args):
+        assert len(args) == 1, "is_box() keyword requires one parameter."
+        particle_id = args[0]
+        assert particle_id.type() == Types.Int32, "Particle ID must be an integer."
+        return ScalarOp.cmp(self.sim.particle_shape[particle_id], Shapes.Box)
+
     def keyword_is_halfspace(self, args):
         assert len(args) == 1, "is_sphere() keyword requires one parameter."
         particle_id = args[0]
         assert particle_id.type() == Types.Int32, "Particle ID must be an integer."
         return ScalarOp.cmp(self.sim.particle_shape[particle_id], Shapes.Halfspace)
 
+    def keyword_is_infinite(self, args):
+        assert len(args) == 1, "is_infinite() keyword requires one parameter."
+        particle_id = args[0]
+        assert particle_id.type() == Types.Int32, "Particle ID must be an integer."
+        return self.sim.particle_flags[particle_id] & Flags.Infinite
+
+    def keyword_is_ghost(self, args):
+        assert len(args) == 1, "is_ghost() keyword requires one parameter."
+        particle_id = args[0]
+        assert particle_id.type() == Types.Int32, "Particle ID must be an integer."
+        return self.sim.particle_flags[particle_id] & Flags.Ghost
+    
+    def keyword_is_fixed(self, args):
+        assert len(args) == 1, "is_fixed() keyword requires one parameter."
+        particle_id = args[0]
+        assert particle_id.type() == Types.Int32, "Particle ID must be an integer."
+        return self.sim.particle_flags[particle_id] & Flags.Fixed
+
+    def keyword_is_global(self, args):
+        assert len(args) == 1, "is_global() keyword requires one parameter."
+        particle_id = args[0]
+        assert particle_id.type() == Types.Int32, "Particle ID must be an integer."
+        return self.sim.particle_flags[particle_id] & Flags.Global
+
     def keyword_select(self, args):
         assert len(args) == 3, "select() keyword requires three parameters!"
         return Select(self.sim, args[0], args[1], args[2])
@@ -125,6 +155,10 @@ class Keywords:
         assert len(args) == 0, "zero_vector() keyword requires no parameter."
         return ZeroVector(self.sim)
 
+    def keyword_vector(self, args):
+        assert len(args) == self.sim.ndims(), "vector()"
+        return Vector(self.sim, [args[i] for i in range(self.sim.ndims())])
+
     def keyword_transposed(self, args):
         assert len(args) == 1, "transposed() keyword requires one parameter."
         matrix = args[0]
@@ -151,13 +185,18 @@ class Keywords:
                                   inv_det * ((matrix[3] * matrix[7]) - (matrix[4] * matrix[6])),
                                   inv_det * ((matrix[6] * matrix[1]) - (matrix[7] * matrix[0])),
                                   inv_det * ((matrix[0] * matrix[4]) - (matrix[1] * matrix[3])) ])
-
+    
     def keyword_diagonal_matrix(self, args):
-        assert len(args) == 1, "diagonal_matrix() keyword requires one parameter!"
-        value = args[0]
-        nelems = Types.number_of_elements(self.sim, Types.Matrix)
-        return Matrix(self.sim, [value if i % (self.sim.ndims() + 1) == 0 else 0.0 \
-                                 for i in range(nelems)])
+        assert len(args) == 1 or len(args) == self.sim.ndims(), f"diagonal_matrix() keyword requires 1 or {self.sim.ndims()} parameters!"
+        if len(args) == 1:
+            value = args[0]
+            nelems = Types.number_of_elements(self.sim, Types.Matrix)
+            return Matrix(self.sim, [value if i % (self.sim.ndims() + 1) == 0 else 0.0 \
+                                    for i in range(nelems)])
+        elif len(args) == self.sim.ndims(): 
+            nelems = Types.number_of_elements(self.sim, Types.Matrix)
+            return Matrix(self.sim, [args[i % self.sim.ndims()] if i % (self.sim.ndims() + 1) == 0 else 0.0 \
+                                    for i in range(nelems)])
 
     def keyword_matrix_multiplication(self, args):
         assert len(args) == 2, "matrix_multiplication() keyword requires two parameters!"
diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py
index 4038f48fb628f6c00c55c5cf535b43e8ce1a9af3..24b7329622c14429a5652ddb9208f6d94849bf43 100644
--- a/src/pairs/sim/cell_lists.py
+++ b/src/pairs/sim/cell_lists.py
@@ -132,7 +132,7 @@ class BuildCellLists(Lowerable):
         for i in ParticleFor(self.sim, local_only=False):
             flat_index = self.sim.add_temp_var(0)
 
-            for _ in Filter(self.sim, ASTTerm.not_op(particle_flags[i] & Flags.Infinite)):
+            for _ in Filter(self.sim, ASTTerm.not_op(particle_flags[i] & (Flags.Infinite | Flags.Global))):
                 cell_index = [
                     Cast.int(self.sim,
                         (positions[i][dim] - (dom_part.min(dim) - spacing[dim])) / spacing[dim]) \
diff --git a/src/pairs/sim/domain_partitioning.py b/src/pairs/sim/domain_partitioning.py
index 485e1a6cbb8c8fe54dabdadd7efd0e7368c3d23b..e72e445bf5c3f029c0b112aacbe8d125fdc53a56 100644
--- a/src/pairs/sim/domain_partitioning.py
+++ b/src/pairs/sim/domain_partitioning.py
@@ -158,30 +158,30 @@ class BlockForest:
         Assign(self.sim, self.rank, Call_Int(self.sim, "pairs_runtime->getDomainPartitioner()->getRank", []))
         Assign(self.sim, self.nranks, Call_Int(self.sim, "pairs_runtime->getNumberOfNeighborRanks", []))
 
-        for _ in Filter(self.sim, ScalarOp.neq(self.nranks, 0)):
-            Assign(self.sim, self.ntotal_aabbs, Call_Int(self.sim, "pairs_runtime->getNumberOfNeighborAABBs", []))
-
-            for _ in Filter(self.sim, self.nranks_capacity < self.nranks):
-                Assign(self.sim, self.nranks_capacity, self.nranks + 10)
-                self.ranks.realloc()
-                self.naabbs.realloc()
-                self.aabb_offsets.realloc()
-
-            for _ in Filter(self.sim, self.aabb_capacity < self.ntotal_aabbs):
-                Assign(self.sim, self.aabb_capacity, self.ntotal_aabbs + 20)
-                self.aabbs.realloc()
-            
-            CopyArray(self.sim, self.ranks, Contexts.Host, Actions.WriteOnly, self.nranks)
-            CopyArray(self.sim, self.naabbs, Contexts.Host, Actions.WriteOnly, self.nranks)
-            CopyArray(self.sim, self.aabb_offsets, Contexts.Host, Actions.WriteOnly, self.nranks)
-            CopyArray(self.sim, self.aabbs, Contexts.Host, Actions.WriteOnly, self.ntotal_aabbs * 6)
-            CopyArray(self.sim, self.subdom, Contexts.Host, Actions.WriteOnly)
-
-            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['ranks', self.ranks, self.nranks])
-            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['naabbs', self.naabbs, self.nranks])
-            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['aabb_offsets', self.aabb_offsets, self.nranks])
-            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['aabbs', self.aabbs, self.ntotal_aabbs * 6])
-            Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['subdom', self.subdom, self.sim.ndims() * 2])
+        # for _ in Filter(self.sim, ScalarOp.neq(self.nranks, 0)): # TODO: Test different block configs with PBC
+        Assign(self.sim, self.ntotal_aabbs, Call_Int(self.sim, "pairs_runtime->getNumberOfNeighborAABBs", []))
+
+        for _ in Filter(self.sim, self.nranks_capacity < self.nranks):
+            Assign(self.sim, self.nranks_capacity, self.nranks + 10)
+            self.ranks.realloc()
+            self.naabbs.realloc()
+            self.aabb_offsets.realloc()
+
+        for _ in Filter(self.sim, self.aabb_capacity < self.ntotal_aabbs):
+            Assign(self.sim, self.aabb_capacity, self.ntotal_aabbs + 20)
+            self.aabbs.realloc()
+        
+        CopyArray(self.sim, self.ranks, Contexts.Host, Actions.WriteOnly, self.nranks)
+        CopyArray(self.sim, self.naabbs, Contexts.Host, Actions.WriteOnly, self.nranks)
+        CopyArray(self.sim, self.aabb_offsets, Contexts.Host, Actions.WriteOnly, self.nranks)
+        CopyArray(self.sim, self.aabbs, Contexts.Host, Actions.WriteOnly, self.ntotal_aabbs * 6)
+        CopyArray(self.sim, self.subdom, Contexts.Host, Actions.WriteOnly)
+
+        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['ranks', self.ranks, self.nranks])
+        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['naabbs', self.naabbs, self.nranks])
+        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['aabb_offsets', self.aabb_offsets, self.nranks])
+        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['aabbs', self.aabbs, self.ntotal_aabbs * 6])
+        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['subdom', self.subdom, self.sim.ndims() * 2])
         
         if isinstance(self.sim.grid, MutableGrid):
             for d in range(self.sim.dims):
diff --git a/src/pairs/sim/global_interaction.py b/src/pairs/sim/global_interaction.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f213ec7aebf9174ea8fa11748f422152830fa7
--- /dev/null
+++ b/src/pairs/sim/global_interaction.py
@@ -0,0 +1,265 @@
+
+from pairs.ir.assign import Assign
+from pairs.ir.scalars import ScalarOp
+from pairs.ir.block import pairs_inline, pairs_device_block, pairs_host_block
+from pairs.ir.branches import Filter
+from pairs.ir.loops import For, ParticleFor
+from pairs.ir.types import Types
+from pairs.ir.device import CopyArray
+from pairs.ir.contexts import Contexts
+from pairs.ir.actions import Actions
+from pairs.ir.sizeof import Sizeof
+from pairs.ir.functions import Call_Void
+from pairs.ir.cast import Cast
+from pairs.sim.flags import Flags
+from pairs.sim.lowerable import Lowerable
+from pairs.sim.interaction import ParticleInteraction
+
+
+class GlobalLocalInteraction(ParticleInteraction):
+    def __init__(self, sim, module_name, nbody, cutoff_radius=None, use_cell_lists=False):
+        super().__init__(sim, module_name, nbody, cutoff_radius, use_cell_lists)
+
+    @pairs_device_block
+    def lower(self):
+        self.sim.module_name(f"{self.module_name}_global_local_interactions")
+        if self.sim._target.is_gpu():
+            first_cell_bytes = self.sim.add_temp_var(0)
+            Assign(self.sim, first_cell_bytes, self.cell_lists.cell_capacity * Sizeof(self.sim, Types.Int32))
+            CopyArray(self.sim, self.cell_lists.cell_sizes, Contexts.Host, Actions.ReadOnly, first_cell_bytes)
+        
+        for ishape in range(self.maxs):
+            if self.include_shape(ishape):
+                # Loop over the global cell
+                for p in For(self.sim, 0, self.cell_lists.cell_sizes[0]):
+                    i = self.cell_lists.cell_particles[0][p]
+                    # TODO: Skip if the bounding box of the global body doesn't intersect the subdom of this rank
+                    for _ in Filter(self.sim, ScalarOp.and_op(
+                        ScalarOp.cmp(self.sim.particle_shape[i], self.sim.get_shape_id(ishape)),
+                        self.sim.particle_flags[i] & (Flags.Infinite | Flags.Global))):
+                        for jshape in range(self.maxs):
+                            if self.include_interaction(ishape, jshape):
+                                # Globals are presenet in all ranks so they should not interact with ghosts
+                                # TODO: Make this loop the kernel candiate and reduce forces on the global body
+                                for j in ParticleFor(self.sim):
+                                    # Here we make make sure not to interact with other global bodies, otherwise
+                                    # their contributions will get reduced again over all ranks
+                                    for _ in Filter(self.sim, ScalarOp.and_op(
+                                        ScalarOp.cmp(self.sim.particle_shape[j], self.sim.get_shape_id(jshape)),
+                                        ScalarOp.not_op(self.sim.particle_flags[j] & (Flags.Infinite | Flags.Global)))):
+                                        for _ in Filter(self.sim, ScalarOp.neq(i, j)):
+                                            self.compute_interaction(i, j, ishape, jshape)
+
+
+class GlobalGlobalInteraction(ParticleInteraction):
+    def __init__(self, sim, module_name, nbody, cutoff_radius=None, use_cell_lists=False):
+        super().__init__(sim, module_name, nbody, cutoff_radius, use_cell_lists)
+
+    @pairs_device_block
+    def lower(self):
+        self.sim.module_name(f"{self.module_name}_global_global_interactions")
+        if self.sim._target.is_gpu():
+            first_cell_bytes = self.sim.add_temp_var(0)
+            Assign(self.sim, first_cell_bytes, self.cell_lists.cell_capacity * Sizeof(self.sim, Types.Int32))
+            CopyArray(self.sim, self.cell_lists.cell_sizes, Contexts.Host, Actions.ReadOnly, first_cell_bytes)
+        
+        for ishape in range(self.maxs):
+            if self.include_shape(ishape):
+                # Loop over the global cell
+                for p in For(self.sim, 0, self.cell_lists.cell_sizes[0]):
+                    i = self.cell_lists.cell_particles[0][p]
+                    for _ in Filter(self.sim, ScalarOp.and_op(
+                        ScalarOp.cmp(self.sim.particle_shape[i], self.sim.get_shape_id(ishape)),
+                        self.sim.particle_flags[i] & (Flags.Infinite | Flags.Global))):
+                        for jshape in range(self.maxs):
+                            if self.include_interaction(ishape, jshape):
+                                # Loop over the global cell
+                                for q in For(self.sim, 0, self.cell_lists.cell_sizes[0]):
+                                    j = self.cell_lists.cell_particles[0][q]
+                                    # Here we only compute interactions with other global bodies
+                                    for _ in Filter(self.sim, ScalarOp.and_op(
+                                        ScalarOp.cmp(self.sim.particle_shape[j], self.sim.get_shape_id(jshape)),
+                                        (self.sim.particle_flags[j] & (Flags.Infinite | Flags.Global)))):
+                                        for _ in Filter(self.sim, ScalarOp.neq(i, j)):
+                                            self.compute_interaction(i, j, ishape, jshape)
+
+class GlobalReduction:
+    def __init__(self, sim, module_name, particle_interaction):
+        self.sim = sim
+        self.module_name            = module_name
+        self.particle_interaction   = particle_interaction
+        self.nglobal_red            = sim.add_var('nglobal_red', Types.Int32)               # Number of global particles that need reduction
+        self.nglobal_capacity       = sim.add_var('nglobal_capacity', Types.Int32, 64)
+        self.global_elem_capacity   = sim.add_var('global_elem_capacity', Types.Int32, 100)
+        self.red_buffer             = sim.add_array('red_buffer', [self.nglobal_capacity, self.global_elem_capacity], Types.Real, arr_sync=False) 
+        self.intermediate_buffer    = sim.add_array('intermediate_buffer', [self.nglobal_capacity, self.global_elem_capacity], Types.Real, arr_sync=False)
+        self.sorted_idx             = sim.add_array('sorted_idx', [self.nglobal_capacity], Types.Int32, arr_sync=False)
+        self.unsorted_idx           = sim.add_array('unsorted_idx', [self.nglobal_capacity], Types.Int32, arr_sync=False)
+        self.removed_idx            = sim.add_array('removed_idx', [self.nglobal_capacity], Types.Boolean, arr_sync=False)
+
+        self.red_props = set()
+        for ishape in range(self.sim.max_shapes()):
+            for jshape in range(self.sim.max_shapes()):
+                if self.particle_interaction.include_interaction(ishape, jshape):
+                    for app in self.particle_interaction.apply_list[ishape*self.sim.max_shapes() + jshape]:
+                        self.red_props.add(app.prop())
+
+        # self.sim.add_statement(self)
+
+    # @pairs_inline
+    # def lower(self):
+    #     SortGlobals(self)
+    #     PackGlobals(self, self.intermediate_buffer)
+    #     ResetReductionProps(self)
+    #     GlobalLocalInteraction(self)
+    #     PackGlobals(self, self.red_buffer)
+    #     ReduceGlobals(self)
+    #     UnpackGlobals(self)
+    #     GlobalGlobalInteraction(self)
+
+    def global_particles(self):
+        for p in For(self.sim, 0, self.sim.cell_lists.cell_sizes[0]):
+            i = self.sim.cell_lists.cell_particles[0][p]
+            for ishape in range(self.sim.max_shapes()):
+                if self.particle_interaction.include_shape(ishape):
+                    for _ in Filter(self.sim, ScalarOp.and_op(
+                        ScalarOp.cmp(self.sim.particle_shape[i], self.sim.get_shape_id(ishape)),
+                        self.sim.particle_flags[i] & (Flags.Infinite | Flags.Global))):
+                        yield i
+
+    def get_elems_per_particle(self):
+        return sum([Types.number_of_elements(self.sim, p.type()) for p in self.red_props])
+    
+
+class SortGlobals(Lowerable):
+    def __init__(self, global_reduction):
+        super().__init__(global_reduction.sim)
+        self.global_reduction = global_reduction
+        self.sim.add_statement(self)
+
+    @pairs_host_block
+    def lower(self):
+        self.sim.module_name(f"{self.global_reduction.module_name}_sort_globals")
+        nglobal_capacity    = self.global_reduction.nglobal_capacity
+        nglobal_red         = self.global_reduction.nglobal_red
+        unsorted_idx        = self.global_reduction.unsorted_idx
+        sorted_idx          = self.global_reduction.sorted_idx
+        removed_idx         = self.global_reduction.removed_idx
+        uid                 = self.sim.particle_uid
+        self.sim.check_resize(nglobal_capacity, nglobal_red)
+
+        Assign(self.sim, nglobal_red, 0)
+        for i in self.global_reduction.global_particles():
+            Assign(self.sim, unsorted_idx[nglobal_red], i)
+            Assign(self.sim, sorted_idx[nglobal_red], 0)
+            Assign(self.sim, removed_idx[nglobal_red], 0)
+            Assign(self.sim, nglobal_red, nglobal_red +1)
+
+        min_uid = self.sim.add_temp_var(0, Types.UInt64)
+        min_idx = self.sim.add_temp_var(0)
+
+        # Here we sort indices of global bodies with respect to their uid's.
+        # The sorted uid's will be in identical order on all ranks. This ensures that the
+        # reduced properties are mapped correctly to each global body during inplace reduction.
+        for i in For(self.sim, 0, nglobal_red):
+            Assign(self.sim, min_uid, -1)   # TODO: Lit max: UINT64_MAX
+            Assign(self.sim, min_idx, 0)
+            for j in For(self.sim, 0, nglobal_red):
+                for _ in Filter(self.sim, ScalarOp.and_op(uid[unsorted_idx[j]] < min_uid,
+                                                          ScalarOp.not_op(removed_idx[j]))):
+                    Assign(self.sim, min_uid, uid[unsorted_idx[j]])
+                    Assign(self.sim, min_idx, j)
+
+            Assign(self.sim, sorted_idx[i], unsorted_idx[min_idx])
+            Assign(self.sim, removed_idx[min_idx], 1)
+
+
+class PackGlobals(Lowerable):
+    def __init__(self, global_reduction, save_state=True):
+        super().__init__(global_reduction.sim)
+        self.global_reduction = global_reduction
+        self.save_state = save_state
+        self.buffer = global_reduction.intermediate_buffer if save_state else global_reduction.red_buffer
+        self.sim.add_statement(self)
+
+    @pairs_device_block
+    def lower(self):
+        self.sim.module_name(f"{self.global_reduction.module_name}_pack_globals_{'intermediate' if self.save_state else 'reduce'}")
+        nglobal_red         = self.global_reduction.nglobal_red
+        sorted_idx          = self.global_reduction.sorted_idx
+        nelems_per_particle = self.global_reduction.get_elems_per_particle()
+        self.buffer.set_stride(1, nelems_per_particle)
+
+        for buffer_idx in For(self.sim, 0, nglobal_red):
+            i = sorted_idx[buffer_idx]
+            p_offset = 0
+            for p in self.global_reduction.red_props:
+                if not Types.is_scalar(p.type()):
+                    nelems = Types.number_of_elements(self.sim, p.type())
+                    for e in range(nelems):
+                        Assign(self.sim, self.buffer[buffer_idx][p_offset + e], p[i][e])
+
+                    p_offset += nelems
+                else:
+                    cast_fn = lambda x: Cast(self.sim, x, Types.Real) if p.type() != Types.Real else x
+                    Assign(self.sim, self.buffer[buffer_idx][p_offset], cast_fn(p[i]))
+                    p_offset += 1
+
+
+class ResetReductionProps(Lowerable):
+    def __init__(self, global_reduction):
+        super().__init__(global_reduction.sim)
+        self.global_reduction = global_reduction
+        self.sim.add_statement(self)
+
+    @pairs_device_block
+    def lower(self):
+        self.sim.module_name(f"{self.global_reduction.module_name}_reset_globals")
+        nglobal_red         = self.global_reduction.nglobal_red
+        sorted_idx          = self.global_reduction.sorted_idx
+
+        for buffer_idx in For(self.sim, 0, nglobal_red):
+            i = sorted_idx[buffer_idx]
+            for p in self.global_reduction.red_props:
+                Assign(self.sim, p[i], 0.0)
+
+class ReduceGlobals(Lowerable):
+    def __init__(self, global_reduction):
+        super().__init__(global_reduction.sim)
+        self.global_reduction = global_reduction
+        self.sim.add_statement(self)
+        
+    @pairs_inline
+    def lower(self):
+        nelems_total = self.global_reduction.nglobal_red * self.global_reduction.get_elems_per_particle() 
+        Call_Void( self.sim, "pairs_runtime->allReduceInplaceSum", [self.global_reduction.red_buffer, nelems_total])
+
+
+class UnpackGlobals(Lowerable):
+    def __init__(self, global_reduction):
+        super().__init__(global_reduction.sim)
+        self.global_reduction = global_reduction
+        self.sim.add_statement(self)
+
+    @pairs_device_block
+    def lower(self):
+        self.sim.module_name(f"{self.global_reduction.module_name}_unpack_globals")
+        nglobal_red = self.global_reduction.nglobal_red
+        sorted_idx  = self.global_reduction.sorted_idx
+        red_buffer  = self.global_reduction.red_buffer
+        intermediate_buffer  = self.global_reduction.intermediate_buffer
+
+        for buffer_idx in For(self.sim, 0, nglobal_red):
+            i = sorted_idx[buffer_idx]
+            p_offset = 0
+            for p in self.global_reduction.red_props:
+                if not Types.is_scalar(p.type()):
+                    nelems = Types.number_of_elements(self.sim, p.type())
+                    for e in range(nelems):
+                        Assign(self.sim, p[i][e], red_buffer[buffer_idx][p_offset + e] + intermediate_buffer[buffer_idx][p_offset + e])
+
+                    p_offset += nelems
+                else:                    
+                    cast_fn = lambda x: Cast(self.sim, x, p.type()) if p.type() != Types.Real else x
+                    Assign(self.sim, p[i], cast_fn(red_buffer[buffer_idx][p_offset] + intermediate_buffer[buffer_idx][p_offset + e]))
+                    p_offset += 1
\ No newline at end of file
diff --git a/src/pairs/sim/interaction.py b/src/pairs/sim/interaction.py
index 34bfa58dfb0d15b97b0ec87bff06ec8ef9d1f729..f19289eb71e809dc531a88fd473cf8e0c1bfce0d 100644
--- a/src/pairs/sim/interaction.py
+++ b/src/pairs/sim/interaction.py
@@ -1,13 +1,14 @@
 from pairs.ir.assign import Assign
 from pairs.ir.ast_term import ASTTerm
 from pairs.ir.scalars import ScalarOp
-from pairs.ir.block import Block, pairs_inline
-from pairs.ir.branches import Filter
+from pairs.ir.block import Block, pairs_device_block
+from pairs.ir.branches import Filter, Branch
 from pairs.ir.loops import For, ParticleFor
-from pairs.ir.math import Sqrt
+from pairs.ir.math import Sqrt, Abs, Min, Sign
 from pairs.ir.select import Select
 from pairs.ir.types import Types
 from pairs.ir.vectors import Vector
+from pairs.ir.print import Print
 from pairs.sim.flags import Flags
 from pairs.sim.lowerable import Lowerable
 from pairs.sim.shapes import Shapes
@@ -46,7 +47,8 @@ class NeighborFor:
         self.particle = particle
         self.cell_lists = cell_lists
         self.neighbor_lists = neighbor_lists
-        self.shapes = range(sim.max_shapes()) if shapes is None else shapes
+        # self.shapes = range(sim.max_shapes()) if shapes is None else shapes
+        self.shapes = [shapes]
 
     def __str__(self):
         return f"NeighborFor<{self.particle}>"
@@ -131,6 +133,7 @@ class NeighborFor:
 
 class InteractionData:
     def __init__(self, sim, shape):
+        self.sim = sim
         self._i = sim.add_symbol(Types.Int32)
         self._j = sim.add_symbol(Types.Int32)
         self._delta = sim.add_symbol(Types.Vector)
@@ -139,6 +142,8 @@ class InteractionData:
         self._contact_point = sim.add_symbol(Types.Vector)
         self._contact_normal = sim.add_symbol(Types.Vector)
         self._shape = shape
+        self.contact_threshold = 0.0
+        self.cutoff_condition = None
 
     def i(self):
         return self._i
@@ -163,133 +168,271 @@ class InteractionData:
 
     def shape(self):
         return self._shape
-
+    
+    def pointmass_pointmass(self, i, j, cutoff_radius):
+        position = self.sim.position()
+        delta = position[i] - position[j]
+        squared_distance = delta.x() * delta.x() + \
+                        delta.y() * delta.y() + \
+                        delta.z() * delta.z()
+        separation_dist = cutoff_radius * cutoff_radius
+        self.cutoff_condition = squared_distance < separation_dist
+
+        self.delta().assign(delta)
+        self.squared_distance().assign(squared_distance)
+
+    def sphere_halfspace(self, i, j):
+        position = self.sim.position()
+        radius = self.sim.property('radius')
+        normal = self.sim.property('normal')
+
+        d = normal[j][0] * position[j][0] + \
+            normal[j][1] * position[j][1] + \
+            normal[j][2] * position[j][2]
+
+        k = normal[j][0] * position[i][0] + \
+            normal[j][1] * position[i][1] + \
+            normal[j][2] * position[i][2]
+
+        penetration_depth = k - radius[i] - d
+        self.cutoff_condition = penetration_depth < self.contact_threshold
+        tmp = radius[i] + penetration_depth
+        contact_normal = normal[j]
+        contact_point = position[i] - Vector(self.sim, [tmp, tmp, tmp]) * normal[j]
+        
+        self.penetration_depth().assign(penetration_depth)
+        self.contact_point().assign(contact_point)
+        self.contact_normal().assign(contact_normal)
+
+    def sphere_sphere(self, i, j):
+        position = self.sim.position()
+        radius = self.sim.property('radius')
+        delta = position[i] - position[j]
+        squared_distance = delta.x() * delta.x() + \
+                        delta.y() * delta.y() + \
+                        delta.z() * delta.z()
+        separation_dist = radius[i] + radius[j] + self.contact_threshold
+        self.cutoff_condition = squared_distance < separation_dist * separation_dist
+        distance = Sqrt(self.sim, squared_distance)
+        penetration_depth = distance - radius[i] - radius[j]
+        contact_normal = delta * (1.0 / distance)
+        k = radius[j] + 0.5 * penetration_depth
+        contact_point = position[j] + contact_normal * k
+
+        self.delta().assign(delta)
+        self.squared_distance().assign(squared_distance)
+        self.penetration_depth().assign(penetration_depth)
+        self.contact_point().assign(contact_point)
+        self.contact_normal().assign(contact_normal)
+
+    def sphere_box(self, i, j, s_relative=True):
+        s = i   # Sphere
+        b = j   # Box
+        position = self.sim.position()
+        edge_length = self.sim.property('edge_length')
+        rotation_matrix = self.sim.property('rotation_matrix')
+        radius = self.sim.property('radius')
+        
+        l = edge_length[b] * 0.5
+        rm = rotation_matrix[b]
+        delta = position[s] - position[b]
+        
+        # Distance to sphere in the box coodinate: p = rm.T * delta 
+        p0 = delta[0]*rm[0] + delta[1]*rm[3] + delta[2]*rm[6]
+        p1 = delta[0]*rm[1] + delta[1]*rm[4] + delta[2]*rm[7]
+        p2 = delta[0]*rm[2] + delta[1]*rm[5] + delta[2]*rm[8]
+
+        # Check if the sphere is outside the box. If yes, clamp p to box bounds
+        outside0 =  ScalarOp.or_op( p0 <  -l[0], p0 >  l[0])
+        p0 = ScalarOp.and_op(p0 >= -l[0], p0 <= l[0]) * p0 + (p0 < -l[0])*(-l[0]) + (p0 > l[0])*l[0]
+        outside1 =  ScalarOp.or_op( p1 <  -l[1], p1 >  l[1])
+        p1 = ScalarOp.and_op(p1 >= -l[1], p1 <= l[1]) * p1 + (p1 < -l[1])*(-l[1]) + (p1 > l[1])*l[1]
+        outside2 =  ScalarOp.or_op( p2 <  -l[2], p2 >  l[2])
+        p2 = ScalarOp.and_op(p2 >= -l[2], p2 <= l[2]) * p2 + (p2 < -l[2])*(-l[2]) + (p2 > l[2])*l[2]
+
+        self.cutoff_condition = self.sim.add_temp_var(0)
+        squared_distance = self.sim.add_temp_var(0.0)
+        penetration_depth = self.sim.add_temp_var(0.0)
+        contact_point = self.sim.add_temp_var([0.0, 0.0, 0.0])
+        contact_normal = self.sim.add_temp_var([0.0, 0.0, 0.0])
+
+        for outside in Branch(self.sim, ScalarOp.or_op(ScalarOp.or_op(outside0, outside1), outside2)):
+            if outside:
+                # Transfrom p to global coordinate: q = rm * p
+                q = Vector(self.sim, [  rm[0]*p0 + rm[1]*p1 + rm[2]*p2, 
+                                        rm[3]*p0 + rm[4]*p1 + rm[5]*p2,
+                                        rm[6]*p0 + rm[7]*p1 + rm[8]*p2])
+                
+                # Normal away from the box
+                n = delta - q
+                Assign(self.sim, squared_distance, n[0]*n[0] + n[1]*n[1] + n[2]*n[2])
+                distance = Sqrt(self.sim, squared_distance)
+                Assign(self.sim, penetration_depth, distance - radius[s])
+                Assign(self.sim, contact_point, position[b] + q)
+                Assign(self.sim, contact_normal, n * (1.0 / distance))
+                Assign(self.sim, self.cutoff_condition, penetration_depth < self.contact_threshold)
+            else:
+                # Print(self.sim, "Sphere", s, "(", position[s][0], position[s][1], position[s][2], ") is inside box", b)
+                # If sphere is inside the box find the closest face
+                dist0 = l[0] - Abs(self.sim, p0)
+                dist1 = l[1] - Abs(self.sim, p1)
+                dist2 = l[2] - Abs(self.sim, p2)
+                mindist = Min(self.sim, Min(self.sim, dist0, dist1), dist2)
+                face = Select(self.sim, ScalarOp.cmp(mindist, dist0), 0, 
+                              Select(self.sim,ScalarOp.cmp(mindist, dist1), 1, 2))
+
+                n = self.sim.add_temp_var([0.0, 0.0, 0.0])
+                # In this case, the normal away from the box is an axis-aligned unit vector in the box coordinate 
+                # Assign(self.sim, n[face], Sign(self.sim, p[face]))   
+                # FIXME: Issue with index generation (a Select node can't be an index for VectorAccess)
+                #        hence the 3 lines below:
+                Assign(self.sim, n[0], Select(self.sim, ScalarOp.cmp(face,0), Sign(self.sim, p0), 0.0))
+                Assign(self.sim, n[1], Select(self.sim, ScalarOp.cmp(face,1), Sign(self.sim, p1), 0.0))
+                Assign(self.sim, n[2], Select(self.sim, ScalarOp.cmp(face,2), Sign(self.sim, p2), 0.0))
+
+                # Transofram n to global coordinates
+                Assign(self.sim, contact_normal[0], rm[0]*n[0] + rm[1]*n[1] + rm[2]*n[2]) 
+                Assign(self.sim, contact_normal[1], rm[3]*n[0] + rm[4]*n[1] + rm[5]*n[2]) 
+                Assign(self.sim, contact_normal[2], rm[6]*n[0] + rm[7]*n[1] + rm[8]*n[2]) 
+
+                Assign(self.sim, penetration_depth, - mindist - radius[s])
+                Assign(self.sim, contact_point, position[s])
+                Assign(self.sim, self.cutoff_condition, penetration_depth < self.contact_threshold)
+
+        self.delta().assign(delta)
+        self.squared_distance().assign(squared_distance)
+        self.penetration_depth().assign(penetration_depth)
+        self.contact_point().assign(contact_point)
+        # Contact normal is always computed towards the sphere
+        self.contact_normal().assign(contact_normal if s_relative else -contact_normal)
 
 class ParticleInteraction(Lowerable):
-    def __init__(self, sim, nbody, cutoff_radius, use_cell_lists=False, split_kernels=False):
+    def __init__(self, sim, module_name,  nbody, cutoff_radius=None, use_cell_lists=False):
         super().__init__(sim)
+        self.module_name = module_name
         self.nbody = nbody
-        self.cutoff_radius = cutoff_radius
         self.contact_threshold = 0.0
         self.use_cell_lists = use_cell_lists
-        self.split_kernels = split_kernels
-        self.nkernels = sim.max_shapes() if split_kernels else 1
+        self.maxs = self.sim.max_shapes()
         self.interactions_data = {}
-        self.blocks = [Block(sim, []) for _ in range(sim.max_shapes())]
-        self.apply_list = [set() for _ in range(self.nkernels)]
+        self.cutoff_radius = cutoff_radius
+        if any(self.sim.get_shape_id(s)==Shapes.PointMass for s in range(self.maxs)): 
+            assert cutoff_radius is not None
+
+        # We reserve n*n blocks and apply_lists for n shapes present in the system, but we only use the included i-j interactions
+        self.blocks = [Block(sim, []) for _ in range(self.maxs * self.maxs)]
+        self.apply_list = [set() for _ in range(self.maxs * self.maxs)]
         self.active_block = None
+        self.cell_lists = self.sim.cell_lists
 
     def add_statement(self, stmt):
         self.active_block.add_statement(stmt)
 
+    # Included interactions
+    # -------------------------------------------------------------------
+    #   i \ j   |   Sphere  |   Halfspace   |   PointMass   |  Box  |
+    # -------------------------------------------------------------------
+    # Sphere    |     1     |       1       |       0       |   1   |
+    # Halfspace |     0     |       0       |       0       |   0   |
+    # PointMass |     0     |       0       |       1       |   0   |
+    # Box       |     1     |       0       |       0       |   0   |
+
+    def include_interaction(self, ishape, jshape):
+        id_i = self.sim.get_shape_id(ishape)
+        id_j = self.sim.get_shape_id(jshape)
+        return  (id_i == Shapes.PointMass and id_j == Shapes.PointMass) or \
+                (id_i == Shapes.Sphere and id_j == Shapes.Sphere) or \
+                (id_i == Shapes.Sphere and id_j == Shapes.Halfspace) or \
+                (id_i == Shapes.Sphere and id_j == Shapes.Box) or \
+                (id_i == Shapes.Box and id_j == Shapes.Sphere)                     
+                    
+    # Included kernels
+    def include_shape(self, ishape):
+        id_i = self.sim.get_shape_id(ishape)
+        return  (id_i == Shapes.PointMass) or \
+                (id_i == Shapes.Sphere) or \
+                (id_i == Shapes.Box)
+
     def __iter__(self):
         self.sim.add_statement(self)
         self.sim.enter(self)
 
         # Neighbors vary across iterations
-        for shape in range(self.sim.max_shapes()):
-            apply_list_id = shape if self.split_kernels else 0
-            self.sim.use_apply_list(self.apply_list[apply_list_id])
-            self.active_block = self.blocks[shape]
-            self.interactions_data[shape] = InteractionData(self.sim, shape)
-            yield self.interactions_data[shape]
-            self.sim.release_apply_list()
+        for ishape in range(self.sim.max_shapes()):
+            for jshape in range(self.sim.max_shapes()):
+                if self.include_interaction(ishape, jshape):
+                    apply_list_id = ishape*self.maxs + jshape
+                    self.sim.use_apply_list(self.apply_list[apply_list_id])
+                    self.active_block = self.blocks[ishape*self.maxs + jshape]
+                    self.interactions_data[ishape*self.maxs + jshape] = InteractionData(self.sim, ishape*self.maxs + jshape)
+                    yield self.interactions_data[ishape*self.maxs + jshape]
+                    self.sim.release_apply_list()
 
         self.sim.leave()
         self.active_block = None
 
-    @pairs_inline
+    def apply_reductions(self, i, ishape, jshape):
+        prop_reductions = {}
+        for app in self.apply_list[ishape*self.maxs + jshape]:
+            prop = app.prop()
+            reduction = app.reduction_variable()
+            if prop not in prop_reductions:
+                prop_reductions[prop] = reduction
+            else:
+                prop_reductions[prop] = prop_reductions[prop] + reduction
+
+        for prop, reduction in prop_reductions.items():
+            Assign(self.sim, prop[i], prop[i] + reduction)
+
+    def compute_interaction(self, i, j, ishape, jshape):
+        interaction_data = self.interactions_data[ishape*self.maxs + jshape]
+        interaction_data.i().assign(i)
+        interaction_data.j().assign(j)
+        ishape_id = self.sim.get_shape_id(ishape)
+        jshape_id = self.sim.get_shape_id(jshape)
+        
+        if ishape_id == Shapes.PointMass and jshape_id == Shapes.PointMass:
+            interaction_data.pointmass_pointmass(i, j, self.cutoff_radius)
+
+        if ishape_id == Shapes.Sphere and jshape_id == Shapes.Sphere:
+            interaction_data.sphere_sphere(i, j)
+
+        if ishape_id == Shapes.Sphere and jshape_id == Shapes.Halfspace:
+            interaction_data.sphere_halfspace(i, j)
+
+        if ishape_id == Shapes.Sphere and jshape_id == Shapes.Box:
+            interaction_data.sphere_box(i, j)
+            
+        if ishape_id == Shapes.Box and jshape_id == Shapes.Sphere:
+            interaction_data.sphere_box(j, i, False)
+
+        # Apply reductions for this i-j interaction:
+        # -------------------------------------------------------------
+        for app in self.apply_list[ishape*self.maxs + jshape]:
+            app.add_reduction_variable()
+
+        # The i-j block is executed only if the cutoff_condition of the i-j interaction is met
+        self.sim.add_statement(Filter(self.sim, interaction_data.cutoff_condition, self.blocks[ishape*self.maxs + jshape]))
+        self.apply_reductions(i, ishape, jshape)
+
+    @pairs_device_block
     def lower(self):
+        self.sim.module_name(f"{self.module_name}_local_interactions")
         if self.nbody == 2:
-            position = self.sim.position()
-            cell_lists = self.sim.cell_lists
             neighbor_lists = None if self.use_cell_lists else self.sim.neighbor_lists
-
-            for kernel in range(self.nkernels):
-                for i in ParticleFor(self.sim):
-                    for _ in Filter(self.sim, ScalarOp.cmp(self.sim.particle_flags[i] & Flags.Fixed, 0)):
-                        for app in self.apply_list[kernel]:
-                            app.add_reduction_variable()
-
-                        shapes = [kernel] if self.split_kernels else None
-                        interaction = kernel
-                        for neigh in NeighborFor(self.sim, i, cell_lists, neighbor_lists, shapes):
-                            interaction_data = self.interactions_data[interaction]
-                            shape = interaction_data.shape()
-                            shape_id = self.sim.get_shape_id(shape)
-                            j = neigh.particle_index()
-
-                            if shape_id == Shapes.PointMass:
-                                delta = position[i] - position[j]
-                                squared_distance = delta.x() * delta.x() + \
-                                                   delta.y() * delta.y() + \
-                                                   delta.z() * delta.z()
-                                separation_dist = self.cutoff_radius * self.cutoff_radius
-                                cutoff_condition = squared_distance < separation_dist
-                                distance = Sqrt(self.sim, squared_distance)
-                                penetration_depth = None
-                                contact_normal = None
-                                contact_point = None
-
-                            elif shape_id == Shapes.Sphere:
-                                radius = self.sim.property('radius')
-                                delta = position[i] - position[j]
-                                squared_distance = delta.x() * delta.x() + \
-                                                   delta.y() * delta.y() + \
-                                                   delta.z() * delta.z()
-                                separation_dist = radius[i] + radius[j] + self.contact_threshold
-                                cutoff_condition = squared_distance < separation_dist * separation_dist
-                                distance = Sqrt(self.sim, squared_distance)
-                                penetration_depth = distance - radius[i] - radius[j]
-                                contact_normal = delta * (1.0 / distance)
-                                k = radius[j] + 0.5 * penetration_depth
-                                contact_point = position[j] + contact_normal * k
-
-                            elif shape_id == Shapes.Halfspace:
-                                radius = self.sim.property('radius')
-                                normal = self.sim.property('normal')
-
-                                d = normal[j][0] * position[j][0] + \
-                                    normal[j][1] * position[j][1] + \
-                                    normal[j][2] * position[j][2]
-
-                                k = normal[j][0] * position[i][0] + \
-                                    normal[j][1] * position[i][1] + \
-                                    normal[j][2] * position[i][2]
-
-                                penetration_depth = k - radius[i] - d
-                                cutoff_condition = penetration_depth < self.contact_threshold
-                                tmp = radius[i] + penetration_depth
-                                contact_normal = normal[j]
-                                contact_point = position[i] - Vector(self.sim, [tmp, tmp, tmp]) * normal[j]
-
-                            else:
-                                raise Exception("Invalid shape id.")
-
-                            interaction_data.i().assign(i)
-                            interaction_data.j().assign(j)
-                            interaction_data.delta().assign(delta)
-                            interaction_data.squared_distance().assign(squared_distance)
-                            interaction_data.penetration_depth().assign(penetration_depth)
-                            interaction_data.contact_point().assign(contact_point)
-                            interaction_data.contact_normal().assign(contact_normal)
-                            self.sim.add_statement(
-                                Filter(self.sim, cutoff_condition, self.blocks[shape]))
-                            interaction += 1
-
-                        prop_reductions = {}
-                        for app in self.apply_list[kernel]:
-                            prop = app.prop()
-                            reduction = app.reduction_variable()
-
-                            if prop not in prop_reductions:
-                                prop_reductions[prop] = reduction
-
-                            else:
-                                prop_reductions[prop] = prop_reductions[prop] + reduction
-
-                        for prop, reduction in prop_reductions.items():
-                            Assign(self.sim, prop[i], prop[i] + reduction)
+            for ishape in range(self.maxs):
+                if self.include_shape(ishape):
+                    # A kernel for each ishape
+                    for i in ParticleFor(self.sim):
+                        for _ in Filter(self.sim, ScalarOp.and_op(
+                            ScalarOp.cmp(self.sim.particle_shape[i], self.sim.get_shape_id(ishape)),
+                            ScalarOp.not_op(self.sim.particle_flags[i] & (Flags.Infinite | Flags.Global)))):
+                            for jshape in range(self.maxs):
+                                if self.include_interaction(ishape, jshape):
+                                    # Inner loops for each jshaped neighbor
+                                    for neigh in NeighborFor(self.sim, i, self.cell_lists, neighbor_lists, jshape):
+                                        j = neigh.particle_index()
+                                        self.compute_interaction(i, j, ishape, jshape)
 
         else:
             raise Exception("Interactions among more than two particles are currently not supported.")
diff --git a/src/pairs/sim/shapes.py b/src/pairs/sim/shapes.py
index 3f768c668e2c1f619159ff4e9b11e5b986b48363..a5f3a23b79303cd29737cc35ed7100c418fc1767 100644
--- a/src/pairs/sim/shapes.py
+++ b/src/pairs/sim/shapes.py
@@ -2,3 +2,4 @@ class Shapes:
     Sphere      =   0
     Halfspace   =   1
     PointMass   =   2
+    Box         =   3
\ No newline at end of file
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index f7360b4e4e8d28c6f4a5407f5af9162188c0a890..432d7e5a61acc7d5000a7204f3570c02ef9b48a7 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -10,7 +10,7 @@ from pairs.ir.symbols import Symbol
 from pairs.ir.types import Types
 from pairs.ir.variables import Variables
 #from pairs.graph.graphviz import ASTGraph
-from pairs.mapping.funcs import compute, setup
+from pairs.mapping.funcs import compute
 from pairs.sim.arrays import DeclareArrays
 from pairs.sim.cell_lists import CellLists, BuildCellLists, BuildCellListsStencil, PartitionCellLists, BuildCellNeighborLists
 from pairs.sim.comm import Comm, Synchronize, Borders, Exchange, ReverseComm
@@ -301,11 +301,11 @@ class Simulation:
         assert self.var(var_name) is None, f"Variable already defined: {var_name}"
         return self.vars.add(var_name, var_type, init_value, runtime)
 
-    def add_temp_var(self, init_value):
-        return self.vars.add_temp(init_value)
+    def add_temp_var(self, init_value, type=None):
+        return self.vars.add_temp(init_value, type)
 
-    def add_symbol(self, sym_type):
-        return Symbol(self, sym_type)
+    def add_symbol(self, sym_type, name=None):
+        return Symbol(self, sym_type, name)
 
     def var(self, var_name):
         return self.vars.find(var_name)
@@ -358,11 +358,9 @@ class Simulation:
         self.neighbor_lists = NeighborLists(self, self.cell_lists)
         return self.neighbor_lists
 
-    def compute(self, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=False, skip_first=False):
-        return compute(self, func, cutoff_radius, symbols, parameters, pre_step, skip_first)
+    def compute(self, func, cutoff_radius=None, symbols={}, parameters={}, compute_globals=False):
+        return compute(self, func, cutoff_radius, symbols, parameters, compute_globals)
 
-    def setup(self, func, symbols={}):
-        return setup(self, func, symbols)
 
     def init_block(self):
         """Initialize new block in this simulation instance"""
@@ -440,6 +438,16 @@ class Simulation:
                 user_defined=True)
         
 
+    def build_interface_module(self, run_on_device=False):
+        """Build a user-defined Module that will be callable seperately as part of the interface"""
+        Module(self, name=self._module_name,
+                block=Block(self, self._block),
+                resizes_to_check=self._resizes_to_check,
+                check_properties_resize=self._check_properties_resize,
+                run_on_device=run_on_device,
+                user_defined=False,
+                interface=True)
+        
     def capture_statements(self, capture=True):
         """When toggled, all constructed statements are captured and automatically added to the last initialized block"""
         self._capture_statements = capture
@@ -531,13 +539,6 @@ class Simulation:
 
     def generate_library(self):
         InterfaceModules(self).create_all()
-        
-        # User defined functions are wrapped inside seperate interface modules here.
-        # The udf's have the same name as their interface module but they get implemented in the pairs::internal scope.
-        for m in self.udf_module_list:
-            module = Module(self, name=m.name, block=Block(self, m), interface=True)
-            module._id = m._id
-
         Transformations(self.interface_modules(), self._target).apply_all()
 
         # Generate library
diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py
index 733d5c10fbaec621d37db6a7009e64493f752719..837286506f288576b8ad82e456e7a4b09cc935a1 100644
--- a/src/pairs/transformations/__init__.py
+++ b/src/pairs/transformations/__init__.py
@@ -73,6 +73,7 @@ class Transformations:
 
     def add_device_copies(self):
         if self._target.is_gpu():
+            self.analysis().fetch_device_copies()
             self.apply(AddDeviceCopies(), [self._module_resizes])
             self.analysis().fetch_modules_references()
 
@@ -104,10 +105,11 @@ class Transformations:
         self.licm()
         self.modularize()
         self.add_device_kernels()
+        self.add_host_references_to_modules()
         self.add_device_copies()
         self.lower(True)
         self.add_expression_declarations()
-        self.add_host_references_to_modules()
+        # self.add_host_references_to_modules()
         self.add_device_references_to_modules()
         
         # TODO: Place stop timers before the function returns
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index d33f30ef174ba06b07d34908aff49926ea982fe2..7a01fee52a1fbdb362d843dc508b26155bfd1d24 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -36,15 +36,22 @@ class AddDeviceCopies(Mutator):
                 if isinstance(s, ModuleCall):
                     copy_context = Contexts.Device if s.module.run_on_device else Contexts.Host
                     clear_context = Contexts.Host if s.module.run_on_device else Contexts.Device
-                    new_stmts += [
-                        Call_Void(ast_node.sim, "pairs::start_timer", [Timers.DeviceTransfers])
-                    ]
+                    # new_stmts += [
+                    #     Call_Void(ast_node.sim, "pairs::start_timer", [Timers.DeviceTransfers])
+                    # ]
 
                     for array, action in s.module.arrays().items():
-                        new_stmts += [CopyArray(s.sim, array, copy_context, action)]
+                        # TODO: Add device copies only if they are not mannualy taken care of inside the module
+                        # if array not in s.module.device_copies():
+                            new_stmts += [CopyArray(s.sim, array, copy_context, action)]
+                            # TODO: Add copyToHost for host references in device modules
+                            # if array in s.module.host_references():
+                            #     new_stmts += [CopyArray(s.sim, array, Contexts.Host, action)]
 
                     for prop, action in s.module.properties().items():
                         new_stmts += [CopyProperty(s.sim, prop, copy_context, action)]
+                        # if prop in s.module.host_references():
+                        #     new_stmts += [CopyProperty(s.sim, prop, Contexts.Host, action)]
 
                     for fp, action in s.module.feature_properties().items():
                         new_stmts += [CopyFeatureProperty(s.sim, fp, copy_context, action)]
@@ -59,18 +66,20 @@ class AddDeviceCopies(Mutator):
                         for var, action in s.module.variables().items():
                             if action != Actions.ReadOnly and var.device_flag:
                                 new_stmts += [CopyVar(s.sim, var, Contexts.Device, action)]
+                            # if var not in s.module.device_copies() and var in s.module.host_references():
+                            #     new_stmts += [CopyVar(s.sim, var, Contexts.Host, action)]
 
-                    new_stmts += [
-                        Call_Void(ast_node.sim, "pairs::stop_timer", [Timers.DeviceTransfers])
-                    ]
+                    # new_stmts += [
+                    #     Call_Void(ast_node.sim, "pairs::stop_timer", [Timers.DeviceTransfers])
+                    # ]
 
                 new_stmts.append(s)
 
                 if isinstance(s, ModuleCall):
                     if s.module.run_on_device:
-                        new_stmts += [
-                            Call_Void(ast_node.sim, "pairs::start_timer", [Timers.DeviceTransfers])
-                        ]
+                        # new_stmts += [
+                        #     Call_Void(ast_node.sim, "pairs::start_timer", [Timers.DeviceTransfers])
+                        # ]
 
                         for var, action in s.module.variables().items():
                             if action != Actions.ReadOnly and var.device_flag:
@@ -78,9 +87,9 @@ class AddDeviceCopies(Mutator):
 
                         if self.module_resizes[s.module]:
                             new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host, Actions.Ignore)]
-                        new_stmts += [
-                            Call_Void(ast_node.sim, "pairs::stop_timer", [Timers.DeviceTransfers])
-                        ]
+                        # new_stmts += [
+                        #     Call_Void(ast_node.sim, "pairs::stop_timer", [Timers.DeviceTransfers])
+                        # ]
 
         ast_node.stmts = new_stmts
         return ast_node