From 169860f3e81caa5c0041cbe198af533f60b8b2f9 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 1 Dec 2023 23:25:48 +0100
Subject: [PATCH] Adjust timers to comprise MPI and device transfers

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 runtime/pairs.cpp                            | 21 ++++++++++++++++++--
 runtime/pairs_common.hpp                     |  7 +++++++
 src/pairs/ir/timers.py                       | 12 +++++++++++
 src/pairs/sim/instrumentation.py             |  7 ++++---
 src/pairs/sim/timestep.py                    |  5 +++--
 src/pairs/transformations/devices.py         | 16 +++++++++++++++
 src/pairs/transformations/instrumentation.py |  3 ++-
 7 files changed, 63 insertions(+), 8 deletions(-)
 create mode 100644 src/pairs/ir/timers.py

diff --git a/runtime/pairs.cpp b/runtime/pairs.cpp
index 93c714c..24e40a0 100644
--- a/runtime/pairs.cpp
+++ b/runtime/pairs.cpp
@@ -301,10 +301,15 @@ void PairsSimulation::communicateSizes(int dim, const int *send_sizes, int *recv
     auto nsend_id = getArrayByHostPointer(send_sizes).getId();
     auto nrecv_id = getArrayByHostPointer(recv_sizes).getId();
 
+    this->getTimers()->start(DeviceTransfers);
     copyArrayToHost(nsend_id, ReadOnly);
     array_flags->setHostFlag(nrecv_id);
     array_flags->clearDeviceFlag(nrecv_id);
+    this->getTimers()->stop(DeviceTransfers);
+
+    this->getTimers()->start(Communication);
     this->getDomainPartitioner()->communicateSizes(dim, send_sizes, recv_sizes);
+    this->getTimers()->stop(Communication);
 }
 
 void PairsSimulation::communicateData(
@@ -319,6 +324,7 @@ void PairsSimulation::communicateData(
     auto nsend_id = getArrayByHostPointer(nsend).getId();
     auto nrecv_id = getArrayByHostPointer(nrecv).getId();
 
+    this->getTimers()->start(DeviceTransfers);
     copyArrayToHost(send_offsets_id, ReadOnly);
     copyArrayToHost(recv_offsets_id, ReadOnly);
     copyArrayToHost(nsend_id, ReadOnly);
@@ -334,13 +340,18 @@ void PairsSimulation::communicateData(
     }
 
     copyArrayToHost(send_buf_id, Ignore, nsend_all * elem_size * sizeof(real_t));
-
     array_flags->setHostFlag(recv_buf_id);
     array_flags->clearDeviceFlag(recv_buf_id);
+    this->getTimers()->stop(DeviceTransfers);
+
+    this->getTimers()->start(Communication);
     this->getDomainPartitioner()->communicateData(
         dim, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv);
+    this->getTimers()->stop(Communication);
 
+    this->getTimers()->start(DeviceTransfers);
     copyArrayToDevice(recv_buf_id, Ignore, nrecv_all * elem_size * sizeof(real_t));
+    this->getTimers()->stop(DeviceTransfers);
 }
 
 void PairsSimulation::communicateAllData(
@@ -355,6 +366,7 @@ void PairsSimulation::communicateAllData(
     auto nsend_id = getArrayByHostPointer(nsend).getId();
     auto nrecv_id = getArrayByHostPointer(nrecv).getId();
 
+    this->getTimers()->start(DeviceTransfers);
     copyArrayToHost(send_offsets_id, ReadOnly);
     copyArrayToHost(recv_offsets_id, ReadOnly);
     copyArrayToHost(nsend_id, ReadOnly);
@@ -370,13 +382,18 @@ void PairsSimulation::communicateAllData(
     }
 
     copyArrayToHost(send_buf_id, Ignore, nsend_all * elem_size * sizeof(real_t));
-
     array_flags->setHostFlag(recv_buf_id);
     array_flags->clearDeviceFlag(recv_buf_id);
+    this->getTimers()->stop(DeviceTransfers);
+
+    this->getTimers()->start(Communication);
     this->getDomainPartitioner()->communicateAllData(
         ndims, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv);
+    this->getTimers()->stop(Communication);
 
+    this->getTimers()->start(DeviceTransfers);
     copyArrayToDevice(recv_buf_id, Ignore, nrecv_all * elem_size * sizeof(real_t));
+    this->getTimers()->stop(DeviceTransfers);
 }
 
 void PairsSimulation::fillCommunicationArrays(int *neighbor_ranks, int *pbc, real_t *subdom) {
diff --git a/runtime/pairs_common.hpp b/runtime/pairs_common.hpp
index a209950..525c0e3 100644
--- a/runtime/pairs_common.hpp
+++ b/runtime/pairs_common.hpp
@@ -37,6 +37,13 @@ enum Actions {
     Ignore = 5
 };
 
+enum Timers {
+    All = 0,
+    Communication = 1,
+    DeviceTransfers = 2,
+    Offset = 3
+};
+
 enum DomainPartitioning {
     DimRanges = 0,
     BoxList,
diff --git a/src/pairs/ir/timers.py b/src/pairs/ir/timers.py
new file mode 100644
index 0000000..b1cd21c
--- /dev/null
+++ b/src/pairs/ir/timers.py
@@ -0,0 +1,12 @@
+class Timers:
+    Invalid = -1
+    All = 0
+    Communication = 1
+    DeviceTransfers = 2
+    Offset = 3
+
+    def name(timer):
+        return "all"            if timer == Timers.All else             \
+               "mpi"            if timer == Timers.Communication else   \
+               "transfers"      if timer == Timers.DeviceTransfers else \
+               "invalid"
diff --git a/src/pairs/sim/instrumentation.py b/src/pairs/sim/instrumentation.py
index a017c04..dedc7c1 100644
--- a/src/pairs/sim/instrumentation.py
+++ b/src/pairs/sim/instrumentation.py
@@ -1,19 +1,20 @@
 from pairs.ir.block import pairs_inline
 from pairs.ir.functions import Call_Void
+from pairs.ir.timers import Timers
 from pairs.sim.lowerable import FinalLowerable
 
-
 class RegisterTimers(FinalLowerable):
     def __init__(self, sim):
         self.sim = sim
 
     @pairs_inline
     def lower(self):
-        Call_Void(self.sim, "pairs::register_timer", [0, "all"])
+        for t in range(Timers.Offset):
+            Call_Void(self.sim, "pairs::register_timer", [t, Timers.name(t)])
 
         for m in self.sim.module_list:
             if m.name != 'main':
-                Call_Void(self.sim, "pairs::register_timer", [m.module_id + 1, m.name])
+                Call_Void(self.sim, "pairs::register_timer", [m.module_id + Timers.Offset, m.name])
 
 
 class RegisterMarkers(FinalLowerable):
diff --git a/src/pairs/sim/timestep.py b/src/pairs/sim/timestep.py
index 44e34ef..1281a4d 100644
--- a/src/pairs/sim/timestep.py
+++ b/src/pairs/sim/timestep.py
@@ -3,6 +3,7 @@ from pairs.ir.block import Block
 from pairs.ir.branches import Branch, Filter
 from pairs.ir.functions import Call_Void
 from pairs.ir.loops import For
+from pairs.ir.timers import Timers
 
 
 class Timestep:
@@ -63,9 +64,9 @@ class Timestep:
         _capture = self.sim._capture_statements
         self.sim.capture_statements(False)
 
-        block = Block(self.sim, [Call_Void(self.sim, "pairs::start_timer", [0]),
+        block = Block(self.sim, [Call_Void(self.sim, "pairs::start_timer", [Timers.All]),
                                  self.timestep_loop,
-                                 Call_Void(self.sim, "pairs::stop_timer", [0])])
+                                 Call_Void(self.sim, "pairs::stop_timer", [Timers.All])])
 
         self.sim.capture_statements(_capture)
         return block
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 3e7c126..9e53b95 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -5,12 +5,14 @@ from pairs.ir.branches import Filter
 from pairs.ir.cast import Cast
 from pairs.ir.contexts import Contexts
 from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef
+from pairs.ir.functions import Call_Void
 from pairs.ir.kernel import Kernel, KernelLaunch
 from pairs.ir.lit import Lit
 from pairs.ir.loops import For
 from pairs.ir.module import ModuleCall
 from pairs.ir.mutator import Mutator
 from pairs.ir.scalars import ScalarOp
+from pairs.ir.timers import Timers
 from pairs.ir.types import Types
 
 
@@ -34,6 +36,9 @@ class AddDeviceCopies(Mutator):
                 if isinstance(s, ModuleCall):
                     copy_context = Contexts.Device if s.module.run_on_device else Contexts.Host
                     clear_context = Contexts.Host if s.module.run_on_device else Contexts.Device
+                    new_stmts += [
+                        Call_Void(ast_node.sim, "pairs::start_timer", [Timers.DeviceTransfers])
+                    ]
 
                     for array, action in s.module.arrays().items():
                         new_stmts += [CopyArray(s.sim, array, copy_context, action)]
@@ -52,16 +57,27 @@ class AddDeviceCopies(Mutator):
                             if action != Actions.ReadOnly and var.device_flag:
                                 new_stmts += [CopyVar(s.sim, var, Contexts.Device, action)]
 
+                    new_stmts += [
+                        Call_Void(ast_node.sim, "pairs::stop_timer", [Timers.DeviceTransfers])
+                    ]
+
                 new_stmts.append(s)
 
                 if isinstance(s, ModuleCall):
                     if s.module.run_on_device:
+                        new_stmts += [
+                            Call_Void(ast_node.sim, "pairs::start_timer", [Timers.DeviceTransfers])
+                        ]
+
                         for var, action in s.module.variables().items():
                             if action != Actions.ReadOnly and var.device_flag:
                                 new_stmts += [CopyVar(s.sim, var, Contexts.Host, action)]
 
                         if self.module_resizes[s.module]:
                             new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host, Actions.Ignore)]
+                        new_stmts += [
+                            Call_Void(ast_node.sim, "pairs::stop_timer", [Timers.DeviceTransfers])
+                        ]
 
         ast_node.stmts = new_stmts
         return ast_node
diff --git a/src/pairs/transformations/instrumentation.py b/src/pairs/transformations/instrumentation.py
index 53d0e52..1e70bdb 100644
--- a/src/pairs/transformations/instrumentation.py
+++ b/src/pairs/transformations/instrumentation.py
@@ -2,6 +2,7 @@ from pairs.ir.block import Block
 from pairs.ir.functions import Call_Void
 from pairs.ir.module import ModuleCall
 from pairs.ir.mutator import Mutator
+from pairs.ir.timers import Timers
 
 
 class AddModulesInstrumentation(Mutator):
@@ -14,7 +15,7 @@ class AddModulesInstrumentation(Mutator):
         if module.name == 'main':
             return ast_node
 
-        timer_id = module.module_id + 1
+        timer_id = module.module_id + Timers.Offset
         start_timer = Call_Void(ast_node.sim, "pairs::start_timer", [timer_id])
         stop_timer = Call_Void(ast_node.sim, "pairs::stop_timer", [timer_id])
 
-- 
GitLab