From efa97a253f2bd7a7680da89bf2543149d75c5dd4 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 7 Oct 2022 17:33:37 +0200
Subject: [PATCH] Small fixes into exchange

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/sim/comm.py       | 34 +++++++++++++++++++++++++++-------
 src/pairs/sim/simulation.py |  2 +-
 2 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index b95000c..a881d04 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -16,8 +16,8 @@ class Comm:
         self.sim = sim
         self.dom_part = dom_part
         self.nsend_all      = sim.add_var('nsend_all', Types.Int32)
-        self.send_capacity  = sim.add_var('send_capacity', Types.Int32, 10000)
-        self.recv_capacity  = sim.add_var('recv_capacity', Types.Int32, 10000)
+        self.send_capacity  = sim.add_var('send_capacity', Types.Int32, 100000)
+        self.recv_capacity  = sim.add_var('recv_capacity', Types.Int32, 100000)
         self.elem_capacity  = sim.add_var('elem_capacity', Types.Int32, 10)
         self.neigh_capacity = sim.add_var('neigh_capacity', Types.Int32, 6)
         self.nsend          = sim.add_array('nsend', [self.neigh_capacity], Types.Int32)
@@ -60,14 +60,23 @@ class Comm:
         prop_list = [self.sim.property(p) for p in ['mass', 'position', 'velocity']]
         for step in range(self.dom_part.number_of_steps()):
             self.nsend_all.set(0)
+            self.sim.nghost.set(0)
+            for s in range(step):
+                for j in self.dom_part.step_indexes(s):
+                    self.nsend[j].set(0)
+                    self.nrecv[j].set(0)
+                    self.send_offsets[j].set(0)
+                    self.recv_offsets[j].set(0)
+
             DetermineGhostParticles(self, step, 0.0)
+            CommunicateSizes(self, step)
+            SetCommunicationOffsets(self, step)
             PackGhostParticles(self, step, prop_list)
             RemoveExchangedParticles_part1(self)
             RemoveExchangedParticles_part2(self, prop_list)
-            CommunicateSizes(self, step)
             CommunicateData(self, step, prop_list)
-            ChangeSizeAfterExchange(self, step)
             UnpackGhostParticles(self, step, prop_list)
+            ChangeSizeAfterExchange(self, step)
 
 
 class CommunicateSizes(Lowerable):
@@ -113,7 +122,9 @@ class DetermineGhostParticles(Lowerable):
         nrecv = self.comm.nrecv
         send_map = self.comm.send_map
         send_mult = self.comm.send_mult
-        ghost_or_exchg = "exchange" if self.spacing == 0.0 else "ghost" # TODO: module_params(self.spacing)
+        exchg_flag = self.comm.exchg_flag
+        is_exchange = (self.spacing == 0.0) # TODO: module_params(self.spacing)
+        ghost_or_exchg = "exchange" if is_exchange else "ghost"
         self.sim.module_name(f"determine_{ghost_or_exchg}_particles{self.step}")
         self.sim.check_resize(self.comm.send_capacity, nsend)
 
@@ -121,9 +132,17 @@ class DetermineGhostParticles(Lowerable):
             nsend[j].set(0)
             nrecv[j].set(0)
 
+        if is_exchange:
+            for i in ParticleFor(self.sim):
+                exchg_flag[i].set(0)
+
         for i, j, _, pbc in self.comm.dom_part.ghost_particles(self.step, self.sim.position(), self.spacing):
             next_idx = AtomicAdd(self.sim, nsend_all, 1)
             send_map[next_idx].set(i)
+
+            if is_exchange:
+                exchg_flag[i].set(1)
+
             for d in range(self.sim.ndims()):
                 send_mult[next_idx][d].set(pbc[d])
 
@@ -244,9 +263,10 @@ class RemoveExchangedParticles_part1(Lowerable):
     @pairs_host_block
     def lower(self):
         self.sim.module_name("remove_exchanged_particles_pt1")
-        send_pos = self.sim.add_temp_var(self.sim.nparticles)
+        send_pos = self.sim.add_temp_var(self.sim.nlocal)
         for i in For(self.sim, 0, self.comm.nsend_all):
-            for need_copy in Branch(self.sim, self.comm.send_map[i] < self.sim.nlocal - self.comm.nsend_all):
+            particle_id = self.comm.send_map[i]
+            for need_copy in Branch(self.sim, particle_id < self.sim.nlocal - self.comm.nsend_all):
                 if need_copy:
                     for _ in While(self.sim, BinOp.cmp(self.comm.exchg_flag[send_pos], 1)):
                         send_pos.set(send_pos - 1)
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 2e73b7b..8f2ab56 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -28,7 +28,7 @@ from pairs.transformations import Transformations
 
 
 class Simulation:
-    def __init__(self, code_gen, dims=3, timesteps=100, particle_capacity=10000):
+    def __init__(self, code_gen, dims=3, timesteps=100, particle_capacity=1000000):
         self.code_gen = code_gen
         self.code_gen.assign_simulation(self)
         self.position_prop = None
-- 
GitLab