diff --git a/src/pairs/sim/domain_partitioning.py b/src/pairs/sim/domain_partitioning.py index 6adaa85c48b2c9562a5165db97195d578474d046..8a496f2a872200cdab54e0808b349c1e4c9ac56e 100644 --- a/src/pairs/sim/domain_partitioning.py +++ b/src/pairs/sim/domain_partitioning.py @@ -137,15 +137,18 @@ class BlockForest: pbc_shifts = [] for d in range(self.sim.ndims()): - cond_pbc_neg = position[i][d] - offset < self.sim.grid.min(d) - cond_pbc_pos = position[i][d] + offset > self.sim.grid.max(d) + aabb_min = self.aabbs[aabb_id][d * 2 + 0] + aabb_max = self.aabbs[aabb_id][d * 2 + 1] + center = aabb_min + (aabb_max - aabb_min) * 0.5 + dist = position[i][d] - center + d_length = self.sim.grid.length(d) + + cond_pbc_neg = dist > (d_length * 0.5) + cond_pbc_pos = dist < -(d_length * 0.5) d_pbc = Select(self.sim, cond_pbc_neg, -1, Select(self.sim, cond_pbc_pos, 1, 0)) - adj_pos = position[i][d] + d_pbc * self.sim.grid.length(d) - d_cond = ScalarOp.and_op( - adj_pos > self.aabbs[aabb_id][d * 2 + 0] + offset, - adj_pos < self.aabbs[aabb_id][d * 2 + 1] - offset) - + adj_pos = position[i][d] + d_pbc * d_length + d_cond = ScalarOp.and_op(adj_pos > aabb_min - offset, adj_pos < aabb_max + offset) full_cond = d_cond if full_cond is None else ScalarOp.and_op(full_cond, d_cond) pbc_shifts.append(d_pbc)