From e40a66620f668f2982f806a9ea2988e9ee5d2f73 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Tue, 30 Aug 2022 20:10:16 +0200
Subject: [PATCH] First fixes for communication code

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 runtime/domain/domain_partitioning.hpp |   1 -
 src/pairs/code_gen/cgen.py             |   3 +
 src/pairs/ir/loops.py                  |   9 +-
 src/pairs/ir/variables.py              |  25 ++--
 src/pairs/sim/comm.py                  | 155 +++++++++++++++++--------
 src/pairs/sim/domain_partitioning.py   |   8 +-
 src/pairs/sim/simulation.py            |  16 ++-
 src/pairs/sim/vtk.py                   |   2 +-
 8 files changed, 145 insertions(+), 74 deletions(-)

diff --git a/runtime/domain/domain_partitioning.hpp b/runtime/domain/domain_partitioning.hpp
index 066694e..1464be8 100644
--- a/runtime/domain/domain_partitioning.hpp
+++ b/runtime/domain/domain_partitioning.hpp
@@ -109,7 +109,6 @@ public:
             send_offset += next_nsend * elem_size;
             recv_offset += prev_nrecv * elem_size;
         }
-
     }
 };
 
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index dfec0db..faa0acb 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -512,6 +512,9 @@ class CGen:
                 return f"({lhs} {operator.symbol()} {rhs})"
 
             if ast_node.is_vector_kind():
+                if index is None:
+                    print(ast_node)
+
                 assert index is not None, "Index must be set for vector reference!"
                 return f"e{ast_node.id()}[{index}]" if ast_node.mem else f"e{ast_node.id()}_{index}"
 
diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py
index 46abc61..b2ea4af 100644
--- a/src/pairs/ir/loops.py
+++ b/src/pairs/ir/loops.py
@@ -28,10 +28,7 @@ class Iter(ASTTerm):
         return Types.Int32
 
     def __eq__(self, other):
-        if isinstance(other, Iter):
-            return self.iter_id == other.iter_id
-
-        return False
+        return isinstance(other, Iter) and self.iter_id == other.iter_id
 
     def __req__(self, other):
         return self.__cmp__(other)
@@ -69,14 +66,14 @@ class For(ASTNode):
 
 class ParticleFor(For):
     def __init__(self, sim, block=None, local_only=True):
-        super().__init__(sim, 0, sim.nlocal if local_only else sim.nlocal + sim.comm.nghost, block)
+        super().__init__(sim, 0, sim.nlocal if local_only else sim.nlocal + sim.nghost, block)
         self.local_only = local_only
 
     def __str__(self):
         return f"ParticleFor<self.iterator>"
 
     def children(self):
-        return [self.block, self.sim.nlocal] + ([] if self.local_only else [self.sim.comm.nghost])
+        return [self.block, self.sim.nlocal] + ([] if self.local_only else [self.sim.nghost])
 
 
 class While(ASTNode):
diff --git a/src/pairs/ir/variables.py b/src/pairs/ir/variables.py
index da13df1..764c6d2 100644
--- a/src/pairs/ir/variables.py
+++ b/src/pairs/ir/variables.py
@@ -1,28 +1,39 @@
 from pairs.ir.ast_node import ASTNode
 from pairs.ir.assign import Assign
 from pairs.ir.bin_op import ASTTerm 
+from pairs.ir.lit import Lit
 
 
 class Variables:
+    temp_id = 0
+
+    def new_temp_id():
+        Variables.temp_id += 1
+        return Variables.temp_id - 1
+
     def __init__(self, sim):
         self.sim = sim
         self.vars = []
         self.nvars = 0
 
     def add(self, v_name, v_type, v_value=0):
-        v = Var(self.sim, v_name, v_type, v_value)
-        self.vars.append(v)
-        return v
+        var = Var(self.sim, v_name, v_type, v_value)
+        self.vars.append(var)
+        return var
+
+    def add_temp(self, init):
+        lit = Lit.cvt(self.sim, init)
+        tmp_id = Variables.new_temp_id()
+        tmp_var = Var(self.sim, f"tmp{tmp_id}", lit.type())
+        self.sim.add_statement(Assign(self.sim, tmp_var, lit))
+        return tmp_var
 
     def all(self):
         return self.vars
 
     def find(self, v_name):
         var = [v for v in self.vars if v.name() == v_name]
-        if var:
-            return var[0]
-
-        return None
+        return var[0] if var else None
 
 
 class Var(ASTTerm):
diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index dae1bea..2239cc3 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -1,8 +1,9 @@
 from pairs.ir.atomic import AtomicAdd
-from pairs.ir.block import pairs_device_block, pairs_host_block
+from pairs.ir.bin_op import BinOp
+from pairs.ir.block import pairs_device_block, pairs_host_block, pairs_inline
 from pairs.ir.branches import Branch, Filter
 from pairs.ir.cast import Cast
-from pairs.ir.loops import For, ParticleFor
+from pairs.ir.loops import For, ParticleFor, While
 from pairs.ir.utils import Print
 from pairs.ir.select import Select
 from pairs.ir.types import Types
@@ -10,9 +11,9 @@ from pairs.sim.lowerable import Lowerable
 
 
 class Comm:
-    def __init__(self, sim, max_neigh_ranks=6, max_buffer_elems=7):
+    def __init__(self, sim, dom_part, max_neigh_ranks=6, max_buffer_elems=7):
         self.sim = sim
-        self.nghost        = sim.add_var('nghost', Types.Int32)
+        self.dom_part = dom_part
         self.nsend_all     = sim.add_var('nsend_all', Types.Int32)
         self.send_capacity = sim.add_var('send_capacity', Types.Int32, 100)
         self.recv_capacity = sim.add_var('recv_capacity', Types.Int32, 100)
@@ -27,40 +28,94 @@ class Comm:
         self.recv_map      = sim.add_array('recv_map', [self.recv_capacity], Types.Int32)
         self.recv_mult     = sim.add_array('recv_mult', [self.recv_capacity, sim.ndims()], Types.Int32)
 
+    @pairs_inline
+    def synchronize(self):
+        prop_list = [self.sim.property(p) for p in ['position']]
+        PackGhostParticles(self, prop_list)
+        CommunicateData(self, prop_list)
+        UnpackGhostParticles(self, prop_list)
+
+    @pairs_inline
+    def borders(self):
+        prop_list = [self.sim.property(p) for p in ['mass', 'position']]
+        for d in For(self.sim, 0, self.sim.ndims()):
+            DetermineGhostParticles(self, d, self.sim.cell_spacing())
+            PackGhostParticles(self, prop_list)
+            CommunicateSizes(self)
+            CommunicateData(self, prop_list)
+            UnpackGhostParticles(self, prop_list)
+
+    @pairs_inline
+    def exchange(self):
+        prop_list = [self.sim.property(p) for p in ['mass', 'position', 'velocity']]
+        for d in For(self.sim, 0, self.sim.ndims()):
+            DetermineGhostParticles(self, d, 0.0)
+            PackGhostParticles(self, prop_list)
+            RemoveExchangedParticles_part1(self)
+            RemoveExchangedParticles_part2(self, prop_list)
+            CommunicateSizes(self)
+            CommunicateData(self, prop_list)
+            ChangeSizeAfterExchange(self)
+            UnpackGhostParticles(self, prop_list)
+
+
+class CommunicateSizes(Lowerable):
+    def __init__(self, comm):
+        super().__init__(comm.sim)
+        self.comm = comm
+
+    @pairs_inline
+    def lower(self):
+        Call_Void(self.sim, "pairs->communicateSizes", [self.nsend, self.nrecv])
+
+
+class CommunicateData(Lowerable):
+    def __init__(self, comm, prop_list):
+        super().__init__(comm.sim)
+        self.comm = comm
+        self.prop_list = prop_list
+
+    @pairs_inline
+    def lower(self):
+        elem_size = sum([self.sim.ndims() if p.type() == Types.Vector else 1])
+        Call_Void(self.sim, "pairs->communicateData", [self.send_buffer, self.nsend, self.recv_buffer, self.nrecv, elem_size])
+
 
 class DetermineGhostParticles(Lowerable):
-    def __init__(self, sim, comm, dom_part, step, spacing):
-        super().__init__(sim)
+    def __init__(self, comm, step, spacing):
+        super().__init__(comm.sim)
         self.comm = comm
-        self.dom_part = dom_part
         self.step = step
         self.spacing = spacing
+        self.sim.add_statement(self)
 
     @pairs_device_block
     def lower(self):
+        nsend_all = self.comm.nsend_all
         nsend = self.comm.nsend
         send_map = self.comm.send_map
         send_mult = self.comm.send_mult
-        sim.module_name("determine_ghost_particles")
-        sim.check_resize(self.comm.send_capacity, nsend)
+        self.sim.module_name("determine_ghost_particles")
+        self.sim.check_resize(self.comm.send_capacity, nsend)
 
         nb_rank_id = 0
         nsend_all.set(0)
-        for i, _, pbc in self.dom_part.ghost_particles(step, self.sim.position(), self.spacing):
+        for i, _, pbc in self.comm.dom_part.ghost_particles(self.step, self.sim.position(), self.spacing):
             n = AtomicAdd(self.sim, nsend_all, 1)
             send_map[n].set(i)
-            for d in self.sim.ndims():
+            for d in range(self.sim.ndims()):
                 send_mult[n][d].set(pbc[d])
 
-            self.nsend[nb_rank_id].add(1)
+            nsend[nb_rank_id].add(1)
             nb_rank_id += 1
 
 
 class PackGhostParticles(Lowerable):
-    def __init__(self, sim, comm, prop_list):
-        super().__init__(sim)
+    def __init__(self, comm, prop_list):
+        super().__init__(comm.sim)
         self.comm = comm
         self.prop_list = prop_list
+        self.sim.add_statement(self)
 
     def get_elems_per_particle(self):
         return sum([self.sim.ndims() if p.type() == Types.Vector else 1 for p in self.prop_list])
@@ -69,17 +124,18 @@ class PackGhostParticles(Lowerable):
     def lower(self):
         send_buffer = self.comm.send_buffer
         send_map = self.comm.send_map
+        send_mult = self.comm.send_mult
         elems_per_particle = self.get_elems_per_particle()
-        sim.module_name("pack_ghost_particles" + sum(["_{p.id()}" for p in self.prop_list]))
+        self.sim.module_name("pack_ghost_particles" + "_".join([str(p.id()) for p in self.prop_list]))
 
-        for i in For(self.sim, self.comm.nsend):
+        for i in For(self.sim, 0, self.comm.nsend):
             p_offset = 0
             for p in self.prop_list:
                 if p.type() == Types.Vector:
-                    for d in self.sim.ndims():
+                    for d in range(self.sim.ndims()):
                         src = p[send_map[i]][d]
                         if p == self.sim.position():
-                            src += send_mult[i][d] * grid.length(d)
+                            src += send_mult[i][d] * self.sim.grid.length(d)
 
                         send_buffer[i * elems_per_particle + p_offset + d].set(src)
 
@@ -92,10 +148,11 @@ class PackGhostParticles(Lowerable):
 
             
 class UnpackGhostParticles(Lowerable):
-    def __init__(self, sim, comm, prop_list):
-        super().__init__(sim)
+    def __init__(self, comm, prop_list):
+        super().__init__(comm.sim)
         self.comm = comm
         self.prop_list = prop_list
+        self.sim.add_statement(self)
 
     def get_elems_per_particle(self):
         return sum([self.sim.ndims() if p.type() == Types.Vector else 1 for p in self.prop_list])
@@ -105,13 +162,13 @@ class UnpackGhostParticles(Lowerable):
         nlocal = self.sim.nlocal
         recv_buffer = self.comm.recv_buffer
         elems_per_particle = self.get_elems_per_particle()
-        sim.module_name("unpack_ghost_particles" + sum(["_{p.id()}" for p in self.prop_list]))
+        self.sim.module_name("unpack_ghost_particles" + "_".join([str(p.id()) for p in self.prop_list]))
 
-        for i in For(self.sim, self.comm.nrecv):
+        for i in For(self.sim, 0, self.comm.nrecv):
             p_offset = 0
             for p in self.prop_list:
                 if p.type() == Types.Vector:
-                    for d in self.sim.ndims():
+                    for d in range(self.sim.ndims()):
                         p[nlocal + i][d].set(recv_buffer[i * elems_per_particle + p_offset + d])
                         
                     p_offset += self.sim.ndims()
@@ -123,45 +180,45 @@ class UnpackGhostParticles(Lowerable):
 
 
 class RemoveExchangedParticles_part1(Lowerable):
-    def __init__(self, sim, comm, prop_list):
-        super().__init__(sim)
+    def __init__(self, comm):
+        super().__init__(comm.sim)
         self.comm = comm
+        self.sim.add_statement(self)
 
     @pairs_host_block
     def lower(self):
-        send_pos = sim.add_temp_var(grid.nparticles)
-        sim.module_name("remove_exchanged_particles_pt1")
-
-        for i in For(self.sim, self.comm.nsend):
+        send_pos = self.sim.add_temp_var(self.sim.nparticles)
+        self.sim.module_name("remove_exchanged_particles_pt1")
+        for i in For(self.sim, 0, self.comm.nsend):
             for is_local in Branch(self.sim, self.comm.send_map[i] < self.sim.nlocal - self.comm.nsend):
                 if is_local:
-                    for _ in While(self.sim, BinOp.eq(self.comm.exchg_flags[send_pos], 1)):
-                        send_pos.assign(send_pos - 1)
+                    for _ in While(self.sim, BinOp.cmp(self.comm.exchg_flag[send_pos], 1)):
+                        send_pos.set(send_pos - 1)
 
                     self.comm.exchg_copy_to[i].set(send_pos)
-                    send_pos.assign(send_pos - 1)
+                    send_pos.set(send_pos - 1)
 
                 else:
                     self.comm.exchg_copy_to[i].set(-1)
 
 
 class RemoveExchangedParticles_part2(Lowerable):
-    def __init__(self, sim, comm, prop_list):
-        super().__init__(sim)
+    def __init__(self, comm, prop_list):
+        super().__init__(comm.sim)
         self.comm = comm
         self.prop_list = prop_list
+        self.sim.add_statement(self)
 
     @pairs_device_block
     def lower(self):
-        sim.module_name("remove_exchanged_particles_pt2")
-
+        self.sim.module_name("remove_exchanged_particles_pt2")
         for i in ParticleFor(self.sim):
             src = self.comm.exchg_copy_to[i]
             for _ in Filter(self.sim, src > 0):
                 dst = self.comm.send_map[i]
                 for p in self.prop_list:
                     if p.type() == Types.Vector:
-                        for d in self.sim.ndims():
+                        for d in range(self.sim.ndims()):
                             p[dst][d].set(p[src][d])
 
                     else:
@@ -170,19 +227,15 @@ class RemoveExchangedParticles_part2(Lowerable):
         self.sim.nlocal.set(self.sim.nlocal - self.comm.nsend)
 
 
-class Synchronize(Lowerable):
-    def __init__(self, sim, comm):
-        super().__init__(sim)
+class ChangeSizeAfterExchange(Lowerable):
+    def __init__(self, comm):
+        super().__init__(comm.sim)
         self.comm = comm
+        self.sim.add_statement(self)
 
-
-class Borders(Lowerable):
-    def __init__(self, sim, comm):
-        super().__init__(sim)
-        self.comm = comm
-
-
-class Exchange(Lowerable):
-    def __init__(self, sim, comm):
-        super().__init__(sim)
-        self.comm = comm
+    @pairs_host_block
+    def lower(self):
+        sim = self.sim
+        sim.module_name("change_size_after_exchange")
+        sim.check_resize(self.sim.particle_capacity, self.sim.nlocal)
+        self.sim.nlocal.set(sim.nlocal + self.comm.nrecv)
diff --git a/src/pairs/sim/domain_partitioning.py b/src/pairs/sim/domain_partitioning.py
index 2606911..af3c54c 100644
--- a/src/pairs/sim/domain_partitioning.py
+++ b/src/pairs/sim/domain_partitioning.py
@@ -18,12 +18,12 @@ class DimensionRanges:
         return self.sim.ndims()
 
     def ghost_particles(self, step, position, offset=0.0):
-        for i in For(sim, 0, self.sim.nlocal + self.sim.comm.nghost):
+        for i in For(self.sim, 0, self.sim.nlocal + self.sim.nghost):
             j = step * 2 + 0
             for _ in Filter(self.sim, position < self.subdom[j] + offset):
-                yield i, self.neighbor_ranks[j], [0 if d != step else self.pbc[j] for d in self.sim.ndims()]
+                yield i, self.neighbor_ranks[j], [0 if d != step else self.pbc[j] for d in range(self.sim.ndims())]
 
-        for i in For(sim, 0, self.sim.nlocal + self.sim.comm.nghost):
+        for i in For(self.sim, 0, self.sim.nlocal + self.sim.nghost):
             j = step * 2 + 1
             for _ in Filter(self.sim, position > self.subdom[j] - offset):
-                yield i, self.neighbor_ranks[j], [0 if d != step else self.pbc[j] for d in self.sim.ndims()]
+                yield i, self.neighbor_ranks[j], [0 if d != step else self.pbc[j] for d in range(self.sim.ndims())]
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 971eda4..f5df6d3 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -12,7 +12,8 @@ from pairs.graph.graphviz import ASTGraph
 from pairs.mapping.funcs import compute
 from pairs.sim.arrays import ArraysDecl
 from pairs.sim.cell_lists import CellLists, CellListsBuild, CellListsStencilBuild
-from pairs.sim.comm import Comm, DetermineGhostParticles, ExchangeParticles, UpdateGhostParticles
+from pairs.sim.comm import Comm
+from pairs.sim.domain_partitioning import DimensionRanges
 from pairs.sim.grid import Grid2D, Grid3D
 from pairs.sim.lattice import ParticleLattice
 from pairs.sim.neighbor_lists import NeighborLists, NeighborListsBuild
@@ -35,6 +36,7 @@ class Simulation:
         self.arrays = Arrays(self)
         self.particle_capacity = self.add_var('particle_capacity', Types.Int32, particle_capacity)
         self.nlocal = self.add_var('nlocal', Types.Int32)
+        self.nghost = self.add_var('nghost', Types.Int32)
         self.resizes = self.add_array('resizes', 3, Types.Int32, arr_sync=False)
         self.grid = None
         self.cell_lists = None
@@ -57,8 +59,7 @@ class Simulation:
         self.iter_id = 0
         self.vtk_file = None
         self._target = None
-        self.comm = Comm(self)
-        self.nparticles = self.nlocal + self.comm.nghost
+        self.nparticles = self.nlocal + self.nghost
         self.properties.add_capacity(self.particle_capacity)
 
     def add_module(self, module):
@@ -120,6 +121,9 @@ class Simulation:
         assert self.var(var_name) is None, f"Variable already defined: {var_name}"
         return self.vars.add(var_name, var_type, init_value)
 
+    def add_temp_var(self, init_value):
+        return self.vars.add_temp(init_value)
+
     def add_symbol(self, sym_type):
         return Symbol(self, sym_type)
 
@@ -222,9 +226,13 @@ class Simulation:
     def generate(self):
         assert self._target is not None, "Target not specified!"
 
+        dom_part = DimensionRanges(self)
+        comm = Comm(self, dom_part)
+
         timestep = Timestep(self, self.ntimesteps, [
+            (comm.exchange(), 20),
             (EnforcePBC(self), 20),
-            (DetermineGhostParticles(self, self.comm), UpdateGhostParticles(self, self.comm), 20),
+            (comm.borders(), comm.synchronize(), 20),
             (CellListsBuild(self, self.cell_lists), 20),
             (NeighborListsBuild(self, self.neighbor_lists), 20),
             PropertiesResetVolatile(self),
diff --git a/src/pairs/sim/vtk.py b/src/pairs/sim/vtk.py
index 0faf8ec..806afc8 100644
--- a/src/pairs/sim/vtk.py
+++ b/src/pairs/sim/vtk.py
@@ -14,7 +14,7 @@ class VTKWrite(Lowerable):
     @pairs_inline
     def lower(self):
         nlocal = self.sim.nlocal
-        nghost = self.sim.comm.nghost
+        nghost = self.sim.nghost
         nall = nlocal + nghost
         Call_Void(self.sim, "pairs::vtk_write_data", [self.filename + "_local", 0, nlocal, self.timestep])
         Call_Void(self.sim, "pairs::vtk_write_data", [self.filename + "_ghost", nlocal, nall, self.timestep])
-- 
GitLab