From 49668121f3b4c8be16ed47307a40511728d8361b Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Mon, 4 Dec 2023 20:53:21 +0100
Subject: [PATCH] Reduce device transfers for DEM case

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 runtime/array.hpp          |  2 +-
 runtime/pairs.cpp          | 61 ++++++++++++--------------------------
 runtime/pairs.hpp          | 59 ++++++++++++++++++++++++++++--------
 src/pairs/code_gen/cgen.py |  8 ++++-
 src/pairs/ir/device.py     |  6 +++-
 src/pairs/sim/comm.py      | 20 +++++++++++--
 6 files changed, 96 insertions(+), 60 deletions(-)

diff --git a/runtime/array.hpp b/runtime/array.hpp
index 1a755c4..03a5385 100644
--- a/runtime/array.hpp
+++ b/runtime/array.hpp
@@ -24,7 +24,7 @@ public:
         PAIRS_ASSERT(size_ > 0);
     }
 
-    property_t getId() { return id; }
+    array_t getId() { return id; }
     std::string getName() { return name; }
     void *getHostPointer() { return h_ptr; }
     void *getDevicePointer() { return d_ptr; }
diff --git a/runtime/pairs.cpp b/runtime/pairs.cpp
index d4291b8..643dfa4 100644
--- a/runtime/pairs.cpp
+++ b/runtime/pairs.cpp
@@ -164,15 +164,11 @@ void PairsSimulation::copyArrayToDevice(Array &array, action_t action, size_t si
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !array_flags->isDeviceFlagSet(array_id)) {
             if(array.isStatic()) {
-                PAIRS_DEBUG("Copying static array %s to device\n", array.getName().c_str());
-                pairs::copy_static_symbol_to_device(
-                    array.getHostPointer(), array.getDevicePointer(), array.getSize());
+                PAIRS_DEBUG("Copying static array %s to device (n=%d)\n", array.getName().c_str(), size);
+                pairs::copy_static_symbol_to_device(array.getHostPointer(), array.getDevicePointer(), size);
             } else {
-                PAIRS_DEBUG("Copying array %s to device\n", array.getName().c_str());
-                pairs::copy_to_device(
-                    array.getHostPointer(),
-                    array.getDevicePointer(),
-                    (size > 0) ? size : array.getSize());
+                PAIRS_DEBUG("Copying array %s to device (n=%d)\n", array.getName().c_str(), size);
+                pairs::copy_to_device(array.getHostPointer(), array.getDevicePointer(), size);
             }
         }
     }
@@ -190,17 +186,12 @@ void PairsSimulation::copyArrayToHost(Array &array, action_t action, size_t size
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !array_flags->isHostFlagSet(array_id)) {
             if(array.isStatic()) {
-                PAIRS_DEBUG("Copying static array %s to host\n", array.getName().c_str());
-                pairs::copy_static_symbol_to_host(
-                    array.getDevicePointer(), array.getHostPointer(), array.getSize());
+                PAIRS_DEBUG("Copying static array %s to host (n=%d)\n", array.getName().c_str(), size);
+                pairs::copy_static_symbol_to_host(array.getDevicePointer(), array.getHostPointer(), size);
             } else {
-                PAIRS_DEBUG("Copying array %s to host\n", array.getName().c_str());
-                pairs::copy_to_host(
-                    array.getDevicePointer(),
-                    array.getHostPointer(),
-                    (size > 0) ? size : array.getSize());
+                PAIRS_DEBUG("Copying array %s to host (n=%d)\n", array.getName().c_str(), size);
+                pairs::copy_to_host(array.getDevicePointer(), array.getHostPointer(), size);
             }
-
         }
     }
 
@@ -216,11 +207,8 @@ void PairsSimulation::copyPropertyToDevice(Property &prop, action_t action, size
 
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !prop_flags->isDeviceFlagSet(prop_id)) {
-            PAIRS_DEBUG("Copying property %s to device\n", prop.getName().c_str());
-            pairs::copy_to_device(
-                prop.getHostPointer(),
-                prop.getDevicePointer(),
-                (size > 0) ? size : prop.getTotalSize());
+            PAIRS_DEBUG("Copying property %s to device (n=%d)\n", prop.getName().c_str(), size);
+            pairs::copy_to_device(prop.getHostPointer(), prop.getDevicePointer(), size);
         }
     }
 
@@ -236,11 +224,8 @@ void PairsSimulation::copyPropertyToHost(Property &prop, action_t action, size_t
 
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !prop_flags->isHostFlagSet(prop_id)) {
-            PAIRS_DEBUG("Copying property %s to host\n", prop.getName().c_str());
-            pairs::copy_to_host(
-                prop.getDevicePointer(),
-                prop.getHostPointer(),
-                (size > 0) ? size : prop.getTotalSize());
+            PAIRS_DEBUG("Copying property %s to host (n=%d)\n", prop.getName().c_str(), size);
+            pairs::copy_to_host(prop.getDevicePointer(), prop.getHostPointer(), size);
         }
     }
 
@@ -258,12 +243,8 @@ void PairsSimulation::copyContactPropertyToDevice(
 
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(action == Ignore || !contact_prop_flags->isDeviceFlagSet(prop_id)) {
-            PAIRS_DEBUG("Copying contact property %s to device\n", contact_prop.getName().c_str());
-            pairs::copy_to_device(
-                contact_prop.getHostPointer(),
-                contact_prop.getDevicePointer(),
-                (size > 0) ? size : contact_prop.getTotalSize());
-
+            PAIRS_DEBUG("Copying contact property %s to device (n=%d)\n", contact_prop.getName().c_str(), size);
+            pairs::copy_to_device(contact_prop.getHostPointer(), contact_prop.getDevicePointer(), size);
             contact_prop_flags->setDeviceFlag(prop_id);
         }
     }
@@ -280,12 +261,8 @@ void PairsSimulation::copyContactPropertyToHost(
 
     if(action == Ignore || action == WriteAfterRead || action == ReadOnly) {
         if(!contact_prop_flags->isHostFlagSet(contact_prop.getId())) {
-            PAIRS_DEBUG("Copying contact property %s to host\n", contact_prop.getName().c_str());
-            pairs::copy_to_host(
-                contact_prop.getDevicePointer(),
-                contact_prop.getHostPointer(),
-                (size > 0) ? size : contact_prop.getTotalSize());
-
+            PAIRS_DEBUG("Copying contact property %s to host (n=%d)\n", contact_prop.getName().c_str(), size);
+            pairs::copy_to_host(contact_prop.getDevicePointer(), contact_prop.getHostPointer(), size);
             contact_prop_flags->setHostFlag(prop_id);
         }
     }
@@ -296,9 +273,9 @@ void PairsSimulation::copyContactPropertyToHost(
 }
 
 void PairsSimulation::copyFeaturePropertyToDevice(FeatureProperty &feature_prop) {
-    PAIRS_DEBUG("Copying static array %s to device\n", feature_prop.getName().c_str());
-    pairs::copy_static_symbol_to_device(
-        feature_prop.getHostPointer(), feature_prop.getDevicePointer(), feature_prop.getArraySize());
+    const size_t n = feature_prop.getArraySize();
+    PAIRS_DEBUG("Copying feature property %s to device (n=%d)\n", feature_prop.getName().c_str(), n);
+    pairs::copy_static_symbol_to_device(feature_prop.getHostPointer(), feature_prop.getDevicePointer(), n);
 }
 
 void PairsSimulation::communicateSizes(int dim, const int *send_sizes, int *recv_sizes) {
diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp
index 3a8a58f..c99a11a 100644
--- a/runtime/pairs.hpp
+++ b/runtime/pairs.hpp
@@ -83,16 +83,27 @@ public:
     template<typename T_ptr>
     void reallocArray(array_t id, T_ptr **h_ptr, T_ptr **d_ptr, size_t size);
 
-    void copyArrayToDevice(array_t id, action_t action, size_t size = 0) {
+    void copyArrayToDevice(array_t id, action_t action) {
+        auto& array = getArray(id);
+        copyArrayToDevice(array, action, array.getSize());
+    }
+
+    void copyArrayToDevice(array_t id, action_t action, size_t size) {
         copyArrayToDevice(getArray(id), action, size);
     }
 
-    void copyArrayToDevice(Array &array, action_t action, size_t size = 0);
-    void copyArrayToHost(array_t id, action_t action, size_t size = 0) {
+    void copyArrayToDevice(Array &array, action_t action, size_t size);
+
+    void copyArrayToHost(array_t id, action_t action) {
+        auto& array = getArray(id);
+        copyArrayToHost(array, action, array.getSize());
+    }
+
+    void copyArrayToHost(array_t id, action_t action, size_t size) {
         copyArrayToHost(getArray(id), action, size);
     }
 
-    void copyArrayToHost(Array &array, action_t action, size_t size = 0);
+    void copyArrayToHost(Array &array, action_t action, size_t size);
 
     // Properties
     Property &getProperty(property_t id);
@@ -154,17 +165,31 @@ public:
         return static_cast<QuaternionProperty&>(getProperty(property));
     }
 
-    void copyPropertyToDevice(property_t id, action_t action, size_t size = 0) {
+    void copyPropertyToDevice(property_t id, action_t action) {
+        auto& prop = getProperty(id);
+        copyPropertyToDevice(prop, action, prop.getTotalSize());
+    }
+
+    void copyPropertyToDevice(property_t id, action_t action, size_t size) {
         copyPropertyToDevice(getProperty(id), action, size);
     }
 
-    void copyPropertyToDevice(Property &prop, action_t action, size_t size = 0);
+    void copyPropertyToDevice(Property &prop, action_t action, size_t size);
 
-    void copyPropertyToHost(property_t id, action_t action, size_t size = 0) {
+    void copyPropertyToHost(property_t id, action_t action) {
+        auto& prop = getProperty(id);
+        copyPropertyToHost(prop, action, prop.getTotalSize());
+    }
+
+    void copyPropertyToHost(property_t id, action_t action, size_t size) {
         copyPropertyToHost(getProperty(id), action, size);
     }
 
-    void copyPropertyToHost(Property &prop, action_t action, size_t size = 0);
+    void copyPropertyToHost(Property &prop, action_t action) {
+        copyPropertyToHost(prop, action, prop.getTotalSize());
+    }
+
+    void copyPropertyToHost(Property &prop, action_t action, size_t size);
 
     // Contact properties
     ContactProperty &getContactProperty(property_t id);
@@ -189,17 +214,27 @@ public:
     void reallocContactProperty(
         property_t id, T_ptr **h_ptr, T_ptr **d_ptr, size_t sx = 1, size_t sy = 1);
 
-    void copyContactPropertyToDevice(property_t id, action_t action, size_t size = 0) {
+    void copyContactPropertyToDevice(property_t id, action_t action) {
+        auto& contact_prop = getContactProperty(id);
+        copyContactPropertyToDevice(contact_prop, action, contact_prop.getTotalSize());
+    }
+
+    void copyContactPropertyToDevice(property_t id, action_t action, size_t size) {
         copyContactPropertyToDevice(getContactProperty(id), action, size);
     }
 
-    void copyContactPropertyToDevice(ContactProperty &prop, action_t action, size_t size = 0);
+    void copyContactPropertyToDevice(ContactProperty &prop, action_t action, size_t size);
+
+    void copyContactPropertyToHost(property_t id, action_t action) {
+        auto& contact_prop = getContactProperty(id);
+        copyContactPropertyToHost(contact_prop, action, contact_prop.getTotalSize());
+    }
 
-    void copyContactPropertyToHost(property_t id, action_t action, size_t size = 0) {
+    void copyContactPropertyToHost(property_t id, action_t action, size_t size) {
         copyContactPropertyToHost(getContactProperty(id), action, size);
     }
 
-    void copyContactPropertyToHost(ContactProperty &prop, action_t action, size_t size = 0);
+    void copyContactPropertyToHost(ContactProperty &prop, action_t action, size_t size);
 
     // Feature properties
     FeatureProperty &getFeatureProperty(property_t id);
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 3b36c03..adf3fb9 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -474,7 +474,13 @@ class CGen:
             array_name = ast_node.array().name()
             ctx_suffix = "Device" if ast_node.context() == Contexts.Device else "Host"
             action = Actions.c_keyword(ast_node.action())
-            self.print(f"pairs->copyArrayTo{ctx_suffix}({array_id}, {action}); // {array_name}")
+            size = self.generate_expression(ast_node.size())
+
+            if size is not None:
+                self.print(f"pairs->copyArrayTo{ctx_suffix}({array_id}, {action}, {size}); // {array_name}")
+
+            else:
+                self.print(f"pairs->copyArrayTo{ctx_suffix}({array_id}, {action}); // {array_name}")
 
         if isinstance(ast_node, CopyContactProperty):
             prop_id = ast_node.contact_prop().id()
diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py
index c52613d..952ff14 100644
--- a/src/pairs/ir/device.py
+++ b/src/pairs/ir/device.py
@@ -30,11 +30,12 @@ class DeviceStaticRef(ASTNode):
 
 
 class CopyArray(ASTNode):
-    def __init__(self, sim, array, ctx, action):
+    def __init__(self, sim, array, ctx, action, size=None):
         super().__init__(sim)
         self._array = array
         self._ctx = ctx
         self._action = action
+        self._size = ScalarOp.inline(size)
         self.sim.add_statement(self)
 
     def array(self):
@@ -46,6 +47,9 @@ class CopyArray(ASTNode):
     def action(self):
         return self._action
 
+    def size(self):
+        return self._size
+
     def children(self):
         return [self._array]
 
diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index 3f552ed..c6429e8 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -11,6 +11,7 @@ from pairs.ir.functions import Call_Void
 from pairs.ir.loops import For, ParticleFor, While
 from pairs.ir.utils import Print
 from pairs.ir.select import Select
+from pairs.ir.sizeof import Sizeof
 from pairs.ir.types import Types
 from pairs.sim.lowerable import Lowerable
 
@@ -27,9 +28,9 @@ class Comm:
         self.nsend          = sim.add_array('nsend', [self.neigh_capacity], Types.Int32)
         self.send_offsets   = sim.add_array('send_offsets', [self.neigh_capacity], Types.Int32)
         self.send_buffer    = sim.add_array('send_buffer', [self.send_capacity, self.elem_capacity], Types.Real, arr_sync=False)
-        self.send_map       = sim.add_array('send_map', [self.send_capacity], Types.Int32)
-        self.exchg_flag     = sim.add_array('exchg_flag', [sim.particle_capacity], Types.Int32)
-        self.exchg_copy_to  = sim.add_array('exchg_copy_to', [self.send_capacity], Types.Int32)
+        self.send_map       = sim.add_array('send_map', [self.send_capacity], Types.Int32, arr_sync=False)
+        self.exchg_flag     = sim.add_array('exchg_flag', [sim.particle_capacity], Types.Int32, arr_sync=False)
+        self.exchg_copy_to  = sim.add_array('exchg_copy_to', [self.send_capacity], Types.Int32, arr_sync=False)
         self.send_mult      = sim.add_array('send_mult', [self.send_capacity, sim.ndims()], Types.Int32)
         self.nrecv          = sim.add_array('nrecv', [self.neigh_capacity], Types.Int32)
         self.recv_offsets   = sim.add_array('recv_offsets', [self.neigh_capacity], Types.Int32)
@@ -116,7 +117,20 @@ class Comm:
             CommunicateSizes(self, step)
             SetCommunicationOffsets(self, step)
             PackGhostParticles(self, step, prop_list)
+
+            if self.sim._target.is_gpu():
+                send_map_size = self.nsend_all * Sizeof(self.sim, Types.Int32)
+                exchg_flag_size = self.sim.nlocal * Sizeof(self.sim, Types.Int32)
+                CopyArray(self.sim, self.send_map, Contexts.Host, Actions.ReadOnly, send_map_size)
+                CopyArray(self.sim, self.exchg_flag, Contexts.Host, Actions.ReadOnly, exchg_flag_size)
+
             RemoveExchangedParticles_part1(self)
+
+            if self.sim._target.is_gpu():
+                exchg_copy_to_size = self.nsend_all * Sizeof(self.sim, Types.Int32)
+                CopyArray(
+                    self.sim, self.exchg_copy_to, Contexts.Device, Actions.ReadOnly, exchg_copy_to_size)
+
             RemoveExchangedParticles_part2(self, prop_list)
             CommunicateData(self, step, prop_list)
             UnpackGhostParticles(self, step, prop_list)
-- 
GitLab