diff --git a/ast/loops.py b/ast/loops.py index 817b6215f08671ed6ab58be5a74ed50b1a885605..915a3365a20e5940d4086f4e160b3e0ced629a20 100644 --- a/ast/loops.py +++ b/ast/loops.py @@ -106,14 +106,16 @@ class For(): class ParticleFor(For): - def __init__(self, sim, block=None): + def __init__(self, sim, block=None, local_only=True): super().__init__(sim, 0, 0, block) + self.local_only = local_only def __str__(self): return f"ParticleFor<>" def generate(self): - self.sim.code_gen.generate_for_preamble(self.iterator.generate(), 0, self.sim.nlocal.generate()) + upper_range = self.sim.nlocal if self.local_only else self.sim.nlocal + self.sim.pbc.npbc + self.sim.code_gen.generate_for_preamble(self.iterator.generate(), 0, upper_range.generate()) self.block.generate() self.sim.code_gen.generate_for_epilogue() diff --git a/sim/cell_lists.py b/sim/cell_lists.py index 28ab65d2efb457efaf5d73357065b5c8b0295998..cd412cde096dc8f21f7ded474dc6023ee236d2cb 100644 --- a/sim/cell_lists.py +++ b/sim/cell_lists.py @@ -84,7 +84,7 @@ class CellListsBuild: for c in For(cl.sim, 0, cl.ncells_all): cl.cell_sizes[c].set(0) - for i in ParticleFor(cl.sim): + for i in ParticleFor(cl.sim, local_only=False): cell_index = [ Cast.int(cl.sim, (positions[i][d] - grid.min(d)) / spc) for d in range(0, cl.sim.dimensions)] @@ -95,9 +95,7 @@ class CellListsBuild: else flat_idx * cl.ncells[d] + cell_index[d]) cell_size = cl.cell_sizes[flat_idx] - for _ in Filter(cl.sim, - Expr.and_op(flat_idx >= 0, - flat_idx <= cl.ncells_all)): + for _ in Filter(cl.sim, Expr.and_op(flat_idx >= 0, flat_idx <= cl.ncells_all)): for cond in Branch(cl.sim, cell_size >= cl.cell_capacity): if cond: resize.set(cell_size)