From fd1df2c3c8eca41cdd3b94455f0108a73b6bc454 Mon Sep 17 00:00:00 2001
From: Behzad Safaei <iwia103h@a0522.nhr.fau.de>
Date: Sat, 15 Mar 2025 23:05:29 +0100
Subject: [PATCH] [Experimental] Add halo cells for ghost/exchange optimization

---
 examples/modular/sphere_box_global.cpp  |   4 +-
 runtime/domain/ParticleDataHandling.hpp |   2 +-
 runtime/domain/block_forest.cpp         |   8 +-
 runtime/domain/regular_6d_stencil.cpp   |   6 +
 runtime/timers.hpp                      |  13 ++-
 runtime/utility/dem_sc_grid.cpp         |   6 +
 src/pairs/__init__.py                   |   2 +-
 src/pairs/code_gen/interface.py         |   2 +-
 src/pairs/mapping/funcs.py              |   2 +-
 src/pairs/sim/cell_lists.py             |  63 ++++++++++-
 src/pairs/sim/comm.py                   |   5 +-
 src/pairs/sim/domain_partitioning.py    | 142 +++++++++++++++++++++++-
 src/pairs/sim/simulation.py             |   7 +-
 13 files changed, 240 insertions(+), 22 deletions(-)

diff --git a/examples/modular/sphere_box_global.cpp b/examples/modular/sphere_box_global.cpp
index 783bbaa..3c30738 100644
--- a/examples/modular/sphere_box_global.cpp
+++ b/examples/modular/sphere_box_global.cpp
@@ -41,7 +41,7 @@ int main(int argc, char **argv) {
 
     auto pairs_runtime = pairs_sim->getPairsRuntime();
 
-    pairs_runtime->initDomain(&argc, &argv, 0, 0, 0, 30, 30, 30, false, false, false, true); 
+    pairs_runtime->initDomain(&argc, &argv, 0, 0, 0, 30, 30, 30, true); 
     pairs_runtime->getDomainPartitioner()->initWorkloadBalancer(pairs::Hilbert, 100, 800);
 
     pairs::create_halfspace(pairs_runtime, 0,0,0,       1, 0, 0,    0, pairs::flags::INFINITE | pairs::flags::FIXED);
@@ -62,7 +62,7 @@ int main(int argc, char **argv) {
     
     // Use the diameter of small particles to set up the cell list
     double lcw = radius * 2;
-    pairs_sim->setup_sim(lcw, lcw, lcw, lcw);
+    pairs_sim->setup_cells(lcw, lcw, lcw, lcw);
     pairs_sim->update_mass_and_inertia();
     pairs_sim->communicate(0);
 
diff --git a/runtime/domain/ParticleDataHandling.hpp b/runtime/domain/ParticleDataHandling.hpp
index c54bae6..c13737c 100644
--- a/runtime/domain/ParticleDataHandling.hpp
+++ b/runtime/domain/ParticleDataHandling.hpp
@@ -290,7 +290,7 @@ public:
         
         // TODO: Check if there is enough particle capacity for the new particles, when there is not,
         // all properties and arrays which have particle_capacity as one of their dimensions must be reallocated
-        PAIRS_ASSERT(nlocal + nrecv < particle_capacity);
+        // PAIRS_ASSERT(nlocal + nrecv < particle_capacity);
 
         for(int i = 0; i < nrecv; ++i) {
             for(auto &prop: ps->getProperties()) {
diff --git a/runtime/domain/block_forest.cpp b/runtime/domain/block_forest.cpp
index 1ad9922..0327536 100644
--- a/runtime/domain/block_forest.cpp
+++ b/runtime/domain/block_forest.cpp
@@ -60,7 +60,7 @@ void BlockForest::updateNeighborhood() {
             auto neighbor_rank = walberla::int_c(block->getNeighborProcess(neigh));
 
             // TODO: Make PBCs work with runtime load balancing
-            // if(neighbor_rank != me) {
+            if(neighbor_rank != me) {
                 const walberla::BlockID& neighbor_id = block->getNeighborId(neigh);
                 walberla::math::AABB neighbor_aabb = block->getNeighborAABB(neigh);
                 auto begin = blocks_pushed[neighbor_rank].begin();
@@ -70,7 +70,7 @@ void BlockForest::updateNeighborhood() {
                     neighborhood[neighbor_rank].push_back(neighbor_aabb);
                     blocks_pushed[neighbor_rank].push_back(neighbor_id);
                 }
-            // }
+            }
         }
     }
 
@@ -265,9 +265,9 @@ void BlockForest::initialize(int *argc, char ***argv) {
     this->info = make_shared<walberla::blockforest::InfoCollection>();
 
     if (rank==0) {
+        std::cout << "Domain Partitioner: BlockForest" << std::endl;
         std::cout << "Domain: " << domain << std::endl;
-        std::cout << "PBC: " << pbc << std::endl;
-        std::cout << "Block config: " << block_config  << std::endl;
+        std::cout << "Configuration: " << block_config  << std::endl;
         std::cout << "Initial refinement level: " << ref_level << std::endl;
         std::cout << "Dynamic load balancing: " << (balance_workload ? "True" : "False") << std::endl;
     }
diff --git a/runtime/domain/regular_6d_stencil.cpp b/runtime/domain/regular_6d_stencil.cpp
index 96ea998..699ea93 100644
--- a/runtime/domain/regular_6d_stencil.cpp
+++ b/runtime/domain/regular_6d_stencil.cpp
@@ -87,6 +87,12 @@ void Regular6DStencil::initialize(int *argc, char ***argv) {
     MPI_Comm_rank(MPI_COMM_WORLD, &rank);
     this->setConfig();
     this->setBoundingBox();
+    if (rank==0) {
+        std::cout << "Domain Partitioner: Regular-6D" << std::endl;
+        std::cout << "Domain: [ <"  << subdom_min[0] << "," << subdom_min[1] << "," << subdom_min[2] << ">, <"
+                                    << subdom_max[0] << "," << subdom_max[1] << "," << subdom_max[2] << "> ]"<< std::endl;
+        std::cout << "Configuration: <" << nranks[0] << "," <<  nranks[1] << "," << nranks[2] << ">" <<std::endl;
+    }
 }
 
 void Regular6DStencil::initWorkloadBalancer(LoadBalancingAlgorithms algorithm, size_t regridMin, size_t regridMax) {}
diff --git a/runtime/timers.hpp b/runtime/timers.hpp
index e82fd32..4c7c3db 100644
--- a/runtime/timers.hpp
+++ b/runtime/timers.hpp
@@ -92,7 +92,8 @@ public:
         // Modules
         for (size_t i = TimerMarkers::Offset; i < time_counters.size(); ++i) {
             const std::string& counterName = counter_names[i];
-            if(counterName.find("INTERFACE_MODULES::") == 0) {
+            // if(counterName.find("INTERFACE_MODULES::") == 0) {
+            if(counterName.length() > 0) {
                 std::cout << std::left << std::setw(80) << counter_names[i]
                         << std::left << std::setw(15) << std::fixed << std::setprecision(2) << time_counters[i]
                         << std::left << std::setw(15) << call_counters[i]
@@ -108,10 +109,20 @@ public:
                     << "\n";
         }
 
+        computeCategories();
+        
+        // Categories
+        for (const auto& cs : categorySums) {;
+            std::cout << std::left << std::setw(80) << cs.first
+                    << std::left << std::setw(15) << std::fixed << std::setprecision(2) << cs.second
+                    << std::left << std::setw(15) << 1
+                    << "\n";
+        }
         std::cout << "--------------------------------------------------------------------------------------------------------\n";
     }
 
     void computeCategories() {
+        categorySums.clear();
         for (size_t i = 0; i < time_counters.size(); ++i) {
             const std::string& counterName = counter_names[i];
             TimeType counterValue = time_counters[i];
diff --git a/runtime/utility/dem_sc_grid.cpp b/runtime/utility/dem_sc_grid.cpp
index 30b2cb0..ffee437 100644
--- a/runtime/utility/dem_sc_grid.cpp
+++ b/runtime/utility/dem_sc_grid.cpp
@@ -31,6 +31,7 @@ int dem_sc_grid(PairsRuntime *ps, double xmax, double ymax, double zmax, double
     auto positions = ps->getAsVectorProperty(ps->getPropertyByName("position"));
     auto velocities = ps->getAsVectorProperty(ps->getPropertyByName("linear_velocity"));
     int nparticles = ps->getTrackedVariableAsInteger("nlocal");
+    int particle_capacity = ps->getTrackedVariableAsInteger("particle_capacity");
 
     const double xmin = 0.0;
     const double ymin = 0.0;
@@ -60,6 +61,11 @@ int dem_sc_grid(PairsRuntime *ps, double xmax, double ymax, double zmax, double
 
         if(ps->getDomainPartitioner()->isWithinSubdomain(point[0], point[1], point[2])) {
             real_t rad = pdiam * 0.5;
+            if(nparticles >= particle_capacity) {
+                std::cerr << "Number of particles exceeded capacity (" << particle_capacity << ") in rank " << ps->getDomainPartitioner()->getRank() << std::endl;
+                // TODO: resize properties, and all arrays that have particle_capacity as a dimension
+                exit(-1);
+            }
             uids(nparticles) = UniqueID::create(ps);
             radius(nparticles) = rad;
             masses(nparticles) = ((4.0 / 3.0) * M_PI) * rad * rad * rad * particle_density;
diff --git a/src/pairs/__init__.py b/src/pairs/__init__.py
index cbde4f8..e23865a 100644
--- a/src/pairs/__init__.py
+++ b/src/pairs/__init__.py
@@ -13,7 +13,7 @@ def simulation(
     double_prec=False,
     use_contact_history=False,
     particle_capacity=800000,
-    neighbor_capacity=100,
+    neighbor_capacity=20,
     debug=False):
 
     return Simulation(
diff --git a/src/pairs/code_gen/interface.py b/src/pairs/code_gen/interface.py
index 1f15c49..8c0b727 100644
--- a/src/pairs/code_gen/interface.py
+++ b/src/pairs/code_gen/interface.py
@@ -188,7 +188,7 @@ class InterfaceModules:
             PrintCode(self.sim, "LIKWID_MARKER_CLOSE;")
             
         Call_Void(self.sim, "pairs::print_timers", [])
-        Call_Void(self.sim, "pairs::log_timers", [])
+        # Call_Void(self.sim, "pairs::log_timers", [])
         Call_Void(self.sim, "pairs::print_stats", [self.sim.nlocal, self.sim.nghost])
         PrintCode(self.sim, "delete pobj;")
         PrintCode(self.sim, "delete pairs_runtime;")
diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py
index f2b1e18..062dfe3 100644
--- a/src/pairs/mapping/funcs.py
+++ b/src/pairs/mapping/funcs.py
@@ -408,5 +408,5 @@ def compute(sim, func, cutoff_radius=None, symbols={}, parameters={}, compute_gl
             
     # User defined functions are wrapped inside seperate interface modules here.
     # The udf's have the same name as their interface module but they get implemented in the pairs::internal scope.
-    sim.build_interface_module_with_statements(run_on_device)  
+    sim.build_interface_module_with_statements()  
     
diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py
index 8501984..09eefb0 100644
--- a/src/pairs/sim/cell_lists.py
+++ b/src/pairs/sim/cell_lists.py
@@ -51,6 +51,11 @@ class CellLists:
         self.stencil            =   self.sim.add_array('stencil', self.nstencil_capacity, Types.Int32)
         self.particle_cell      =   self.sim.add_array('particle_cell', self.sim.particle_capacity, Types.Int32)
 
+        if sim._use_halo_cells:
+            self.halo_ncells        =   self.sim.add_var('halo_ncells', Types.Int32, 0)
+            self.halo_ncells_capacity = self.sim.add_var('halo_ncells_capacity', Types.Int32, 10000)
+            self.halo_cells         =   self.sim.add_array('halo_cells', self.halo_ncells_capacity, Types.Int32)
+
         if sim._store_neighbors_per_cell:
             self.cell_neigh_capacity = self.sim.add_var('cell_neigh_capacity', Types.Int32, 80)
             self.cell_nneighs = self.sim.add_array('cell_nneighs', [self.ncells_capacity, self.sim.max_shapes()], Types.Int32)
@@ -86,7 +91,7 @@ class BuildCellListsStencil(Lowerable):
         for dim in range(self.sim.ndims()):
             dim_min = self.cell_lists.dom_part.min(dim) - spacing[dim]
             dim_max = self.cell_lists.dom_part.max(dim) + spacing[dim]
-            Assign(self.sim, dim_ncells[dim], Ceil(self.sim, (dim_max - dim_min) / spacing[dim]) + 1)
+            Assign(self.sim, dim_ncells[dim], Ceil(self.sim, (dim_max - dim_min) / spacing[dim]))
             ntotal_cells *= dim_ncells[dim]
 
         Assign(self.sim, ncells, ntotal_cells + 1)
@@ -102,6 +107,62 @@ class BuildCellListsStencil(Lowerable):
                         Assign(self.sim, stencil[nstencil], index)
                         Assign(self.sim, nstencil, nstencil + 1)
 
+        # Halo cell generation
+        # TODO: Defer halo cells generation to dom_part
+        # ----------------------------------------------------
+        if self.sim._use_halo_cells:
+            halo_ncells_capacity = self.cell_lists.halo_ncells_capacity
+            n = self.cell_lists.halo_ncells
+            self.sim.check_resize(halo_ncells_capacity, n)
+            halo_cells = self.cell_lists.halo_cells
+            Assign(self.sim, n, 1)
+
+            # Note: We add +2 to each layer since it's possible that the outermost local layer of  
+            # master doesn't fully overlap with innermost ghost layer of neighbor, and vice versa.
+            layers_0 = self.sim.add_temp_var(0) 
+            Assign(self.sim, layers_0, Ceil(self.sim, (cutoff_radius / spacing[0])) + 2)
+            layers_1 = self.sim.add_temp_var(0)
+            Assign(self.sim, layers_1, Ceil(self.sim, (cutoff_radius / spacing[1])) + 2)
+            layers_2 = self.sim.add_temp_var(0)
+            Assign(self.sim, layers_2, Ceil(self.sim, (cutoff_radius / spacing[2])) + 2)
+
+            # TODO: Merge these loops.
+            # X faces
+            for y in For(self.sim, 0, dim_ncells[1]):
+                for z in For(self.sim, 0, dim_ncells[2]):
+                    for x in For(self.sim, 0, layers_0):
+                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
+                        Assign(self.sim, halo_cells[n], index + 1)
+                        Assign(self.sim, n, n+1)
+                    for x in For(self.sim, dim_ncells[0]-layers_0, dim_ncells[0]):
+                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
+                        Assign(self.sim, halo_cells[n], index + 1)
+                        Assign(self.sim, n, n+1)
+            
+            # Y faces (excluding X edges)
+            for x in For(self.sim, layers_0, dim_ncells[0]-layers_0):
+                for z in For(self.sim, 0, dim_ncells[2]):
+                    for y in For(self.sim, 0, layers_1):
+                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
+                        Assign(self.sim, halo_cells[n], index + 1)
+                        Assign(self.sim, n, n+1)
+                    for y in For(self.sim, dim_ncells[1]-layers_1, dim_ncells[1]):
+                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
+                        Assign(self.sim, halo_cells[n], index + 1)
+                        Assign(self.sim, n, n+1)
+            
+            # Z faces (exluding X and Y edges)
+            for x in For(self.sim, layers_0, dim_ncells[0]-layers_0):
+                for y in For(self.sim, layers_1, dim_ncells[1]-layers_1):
+                    for z in For(self.sim, 0, layers_2):
+                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
+                        Assign(self.sim, halo_cells[n], index + 1)
+                        Assign(self.sim, n, n+1)
+                    for z in For(self.sim, dim_ncells[2]-layers_2, dim_ncells[2]):
+                        index = x*dim_ncells[1]*dim_ncells[2] + y*dim_ncells[2] + z
+                        Assign(self.sim, halo_cells[n], index + 1)
+                        Assign(self.sim, n, n+1)
+        
 
 class BuildCellLists(Lowerable):
     def __init__(self, sim, cell_lists):
diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index 00622af..54e4e0c 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -319,10 +319,9 @@ class DetermineGhostParticles(Lowerable):
         is_exchange = (self.spacing == 0.0) # TODO: module_params(self.spacing)
         ghost_or_exchg = "exchange" if is_exchange else "ghost"
         self.sim.module_name(f"determine_{ghost_or_exchg}_particles{self.step}")
-        self.sim.check_resize(self.comm.send_capacity, nsend)
-        #self.sim.check_resize(self.comm.send_capacity, nsend_all)
+        # self.sim.check_resize(self.comm.send_capacity, nsend)
+        self.sim.check_resize(self.comm.send_capacity, nsend_all)   # send_buffer packs data for all ranks
 
-        # PrintCode(self.sim, f"std::cout << \"resizes[0] {self.sim._module_name} ========== \" << pobj->resizes[0] << std::endl;")
         if is_exchange:
             for i in ParticleFor(self.sim):
                 Assign(self.sim, exchg_flag[i], 0)
diff --git a/src/pairs/sim/domain_partitioning.py b/src/pairs/sim/domain_partitioning.py
index f90f61f..fc5b87b 100644
--- a/src/pairs/sim/domain_partitioning.py
+++ b/src/pairs/sim/domain_partitioning.py
@@ -1,6 +1,6 @@
 from pairs.ir.assign import Assign
 from pairs.ir.branches import Filter
-from pairs.ir.loops import For
+from pairs.ir.loops import For, Continue
 from pairs.ir.functions import Call_Int, Call_Void, Call
 from pairs.ir.scalars import ScalarOp
 from pairs.ir.select import Select
@@ -11,6 +11,7 @@ from pairs.sim.grid import MutableGrid
 from pairs.ir.device import CopyArray
 from pairs.ir.contexts import Contexts
 from pairs.ir.actions import Actions
+from pairs.ir.print import Print
 
 class DimensionRanges:
     def __init__(self, sim):
@@ -90,13 +91,70 @@ class DimensionRanges:
             for _ in Filter(self.sim, ScalarOp.inline(ScalarOp.cmp(self.pbc[j], 0))):
                 yield from prev_neighbor(self, j, step, position, offset, flags_to_exclude)
 
+    #### cell lists need to be rebuilt in each phase
+    #### --------------------------------------------------------------------
+    # def ghost_particles_halo_cells(self, step, position, offset=0.0):
+    #     # Particles with one of the following flags are ignored
+    #     flags_to_exclude = (Flags.Infinite | Flags.Global)
+
+    #     def next_neighbor(self, j, step, position, offset, flags_to_exclude):
+    #         particle_flags = self.sim.particle_flags
+    #         cells_to_check = self.sim.cell_lists.halo_cells
+    #         ncells_to_check = self.sim.cell_lists.halo_ncells
+    #         for nc in For(self.sim, 0, ncells_to_check):
+    #             c = self.sim.add_temp_var(0)
+    #             Assign(self.sim, c, cells_to_check[nc])
+    #             for p in For(self.sim, 0, self.sim.cell_lists.cell_sizes[c]):
+    #                 i = self.sim.cell_lists.cell_particles[c][p]
+    #                 particle_flags = self.sim.particle_flags
+
+    #                 # Don't check ghost particles if in Exchange mode
+    #                 if offset==0.0:
+    #                     for _ in Filter(self.sim, i >= self.sim.nlocal):
+    #                         Continue(self.sim)()
+
+    #                 for _ in Filter(self.sim, ScalarOp.cmp(particle_flags[i] & flags_to_exclude, 0)):
+    #                     for _ in Filter(self.sim, position[i][step] < self.subdom[j] + offset):
+    #                         pbc_shifts = [0 if d != step else self.pbc[j] for d in range(self.sim.ndims())]
+    #                         yield i, j, self.neighbor_ranks[j], pbc_shifts
+
+    #     def prev_neighbor(self, j, step, position, offset, flags_to_exclude):
+    #         particle_flags = self.sim.particle_flags
+    #         cells_to_check = self.sim.cell_lists.halo_cells
+    #         ncells_to_check = self.sim.cell_lists.halo_ncells
+    #         for nc in For(self.sim, 0, ncells_to_check):
+    #             c = self.sim.add_temp_var(0)
+    #             Assign(self.sim, c, cells_to_check[nc])
+    #             for p in For(self.sim, 0, self.sim.cell_lists.cell_sizes[c]):
+    #                 i = self.sim.cell_lists.cell_particles[c][p]
+    #                 particle_flags = self.sim.particle_flags
+                    
+    #                 # Don't check ghost particles if in Exchange mode
+    #                 if offset==0.0:
+    #                     for _ in Filter(self.sim, i >= self.sim.nlocal):
+    #                         Continue(self.sim)()
+                            
+    #                 for _ in Filter(self.sim, ScalarOp.cmp(particle_flags[i] & flags_to_exclude, 0)):
+    #                     for _ in Filter(self.sim, position[i][step] > self.subdom[j] - offset):
+    #                         pbc_shifts = [0 if d != step else self.pbc[j] for d in range(self.sim.ndims())]
+    #                         yield i, j, self.neighbor_ranks[j], pbc_shifts
+
+    #     if self.sim._pbc[step]:
+    #         yield from next_neighbor(self, step * 2 + 0, step, position, offset, flags_to_exclude)
+    #         yield from prev_neighbor(self, step * 2 + 1, step, position, offset, flags_to_exclude)
+
+    #     else:
+    #         j = step * 2 + 0
+    #         for _ in Filter(self.sim, ScalarOp.inline(ScalarOp.cmp(self.pbc[j], 0))):
+    #             yield from next_neighbor(self, j, step, position, offset, flags_to_exclude)
+
+    #         j = step * 2 + 1
+    #         for _ in Filter(self.sim, ScalarOp.inline(ScalarOp.cmp(self.pbc[j], 0))):
+    #             yield from prev_neighbor(self, j, step, position, offset, flags_to_exclude)
 
 class BlockForest:
     def __init__(self, sim):
         self.sim                = sim
-        self.load_balancer      = None
-        self.regrid_min         = None
-        self.regrid_max         = None
         self.reduce_step        = sim.add_var('reduce_step', Types.Int32)   # this var is treated as a tmp (workaround for gpu)
         self.reduce_step.force_read = True
         self.rank               = sim.add_var('rank', Types.Int32)
@@ -169,6 +227,13 @@ class BlockForest:
                 Assign(self.sim, self.sim.grid.max(d), Call(self.sim, "pairs_runtime->getDomainPartitioner()->getMax", [d], Types.Real))
 
     def ghost_particles(self, step, position, offset=0.0):
+        if self.sim._use_halo_cells:
+            # No support for adaptive blocks yet
+            yield from self.ghost_particles_halo_cells(step, position, offset)
+        else:
+            yield from self.ghost_particles_original(step, position, offset)
+
+    def ghost_particles_original(self, step, position, offset=0.0):
         ''' TODO :  If we have pbc, a sinlge particle can be a ghost particle multiple times (at different locations) for the same neighbor block,
                     so this function should have the capability to yield more than one particle for every neighbor.
                     But currently it doesn't have that capability, so we need at least 2 blocks in the dimensions that we have pbc.
@@ -231,3 +296,72 @@ class BlockForest:
                             
                             for _ in Filter(self.sim, isghost):
                                 yield i, r, self.ranks[r], pbc_shifts
+
+
+    def ghost_particles_halo_cells(self, step, position, offset=0.0):
+        # Particles with one of the following flags are ignored
+        flags_to_exclude = (Flags.Infinite | Flags.Global)
+        cells_to_check = self.sim.cell_lists.halo_cells
+        ncells_to_check = self.sim.cell_lists.halo_ncells
+        for r in self.step_indexes(0):     # for every neighbor rank
+            for nc in For(self.sim, 0, ncells_to_check):
+                c = self.sim.add_temp_var(0)
+                Assign(self.sim, c, cells_to_check[nc])
+                for p in For(self.sim, 0, self.sim.cell_lists.cell_sizes[c]):
+                    i = self.sim.cell_lists.cell_particles[c][p]
+                    particle_flags = self.sim.particle_flags
+
+                    # Skip ghost particles
+                    for _ in Filter(self.sim, i >= self.sim.nlocal):
+                        Continue(self.sim)()
+
+                    for _ in Filter(self.sim, ScalarOp.cmp(particle_flags[i] & flags_to_exclude, 0)):
+                        for aabb_id in For(self.sim, self.aabb_offsets[r], self.aabb_offsets[r] + self.naabbs[r]): # for every aabb of this neighbor
+                            for _ in Filter(self.sim, ScalarOp.neq(self.ranks[r] , self.rank)):     # if my neighobr is not my own rank
+                                full_cond = None
+                                pbc_shifts = []
+
+                                for d in range(self.sim.ndims()):
+                                    aabb_min = self.aabbs[aabb_id][d * 2 + 0]
+                                    aabb_max = self.aabbs[aabb_id][d * 2 + 1]
+                                    d_pbc = 0
+                                    d_length = self.sim.grid.length(d)
+
+                                    if self.sim._pbc[d]:
+                                        center = aabb_min + (aabb_max - aabb_min) * 0.5     # center of neighbor block
+                                        dist = position[i][d] - center                      # distance of our particle from center of neighbor
+                                        cond_pbc_neg = dist >  (d_length * 0.5)
+                                        cond_pbc_pos = dist < -(d_length * 0.5)
+
+                                        d_pbc = Select(self.sim, cond_pbc_neg, -1, Select(self.sim, cond_pbc_pos, 1, 0))
+
+                                    adj_pos = position[i][d] + d_pbc * d_length 
+                                    d_cond = ScalarOp.and_op(adj_pos > aabb_min - offset, adj_pos < aabb_max + offset)
+                                    full_cond = d_cond if full_cond is None else ScalarOp.and_op(full_cond, d_cond)
+                                    pbc_shifts.append(d_pbc)
+
+                                for _ in Filter(self.sim, full_cond):
+                                    yield i, r, self.ranks[r], pbc_shifts
+
+                            for _ in Filter(self.sim, ScalarOp.cmp(self.ranks[r] , self.rank)):     # if my neighbor is me
+                                pbc_shifts = []
+                                isghost = Lit(self.sim, 0)
+
+                                for d in range(self.sim.ndims()):
+                                    aabb_min = self.aabbs[aabb_id][d * 2 + 0]
+                                    aabb_max = self.aabbs[aabb_id][d * 2 + 1]
+                                    center = aabb_min + (aabb_max - aabb_min) * 0.5     # center of neighbor block
+                                    dist = position[i][d] - center                      # distance of our particle from center of neighbor
+                                    d_pbc = 0
+                                    d_length = self.sim.grid.length(d)
+
+                                    if self.sim._pbc[d]:
+                                        cond_pbc_neg = dist >  (d_length*0.5 - offset)
+                                        cond_pbc_pos = dist < -(d_length*0.5 - offset)
+                                        d_pbc = Select(self.sim, cond_pbc_neg, -1, Select(self.sim, cond_pbc_pos, 1, 0))
+                                        isghost = ScalarOp.or_op(isghost, d_pbc)
+
+                                    pbc_shifts.append(d_pbc)
+                                
+                                for _ in Filter(self.sim, isghost):
+                                    yield i, r, self.ranks[r], pbc_shifts
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 3f8f084..0d5f2f8 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -62,6 +62,7 @@ class Simulation:
         self.grid = None
 
         # Acceleration structures
+        self._use_halo_cells = False
         self.cell_lists = None
         self._store_neighbors_per_cell = False
         self.neighbor_lists = None
@@ -259,10 +260,11 @@ class Simulation:
     def reneighbor_every(self, frequency):
         self.reneighbor_frequency = frequency
 
-    def build_cell_lists(self, spacing=None, store_neighbors_per_cell=False):
+    def build_cell_lists(self, spacing=None, store_neighbors_per_cell=False, use_halo_cells=False):
         """Add routines to build the linked-cells acceleration structure.
         Leave spacing as None so it can be set at runtime."""
         self._store_neighbors_per_cell = store_neighbors_per_cell
+        self._use_halo_cells = use_halo_cells
         self.cell_lists = CellLists(self, self._dom_part, spacing, spacing)
         return self.cell_lists
 
@@ -302,13 +304,12 @@ class Simulation:
         else:
             raise Exception("Two sizes assigned to same capacity!")
 
-    def build_interface_module_with_statements(self, run_on_device=False):
+    def build_interface_module_with_statements(self):
         """Build a user-defined Module that will be callable seperately as part of the interface"""
         Module(self, name=self._module_name,
                 block=Block(self, self._block),
                 resizes_to_check=self._resizes_to_check,
                 check_properties_resize=self._check_properties_resize,
-                run_on_device=run_on_device,
                 interface=True)
         
     def capture_statements(self, capture=True):
-- 
GitLab