diff --git a/src/pairs/sim/domain_partitioning.py b/src/pairs/sim/domain_partitioning.py
index 6adaa85c48b2c9562a5165db97195d578474d046..8a496f2a872200cdab54e0808b349c1e4c9ac56e 100644
--- a/src/pairs/sim/domain_partitioning.py
+++ b/src/pairs/sim/domain_partitioning.py
@@ -137,15 +137,18 @@ class BlockForest:
                         pbc_shifts = []
 
                         for d in range(self.sim.ndims()):
-                            cond_pbc_neg = position[i][d] - offset < self.sim.grid.min(d)
-                            cond_pbc_pos = position[i][d] + offset > self.sim.grid.max(d)
+                            aabb_min = self.aabbs[aabb_id][d * 2 + 0]
+                            aabb_max = self.aabbs[aabb_id][d * 2 + 1]
+                            center = aabb_min + (aabb_max - aabb_min) * 0.5
+                            dist = position[i][d] - center
+                            d_length = self.sim.grid.length(d)
+
+                            cond_pbc_neg = dist >  (d_length * 0.5)
+                            cond_pbc_pos = dist < -(d_length * 0.5)
                             d_pbc = Select(self.sim, cond_pbc_neg, -1, Select(self.sim, cond_pbc_pos, 1, 0))
 
-                            adj_pos = position[i][d] + d_pbc * self.sim.grid.length(d)
-                            d_cond = ScalarOp.and_op(
-                                adj_pos > self.aabbs[aabb_id][d * 2 + 0] + offset,
-                                adj_pos < self.aabbs[aabb_id][d * 2 + 1] - offset)
-
+                            adj_pos = position[i][d] + d_pbc * d_length
+                            d_cond = ScalarOp.and_op(adj_pos > aabb_min - offset, adj_pos < aabb_max + offset)
                             full_cond = d_cond if full_cond is None else ScalarOp.and_op(full_cond, d_cond)
                             pbc_shifts.append(d_pbc)