Skip to content
Snippets Groups Projects
Commit 169860f3 authored by Rafael Ravedutti
Browse files

Adjust timers to comprise MPI and device transfers

parent c9f3c097
Branches
Tags
No related merge requests found
......@@ -301,10 +301,15 @@ void PairsSimulation::communicateSizes(int dim, const int *send_sizes, int *recv
auto nsend_id = getArrayByHostPointer(send_sizes).getId();
auto nrecv_id = getArrayByHostPointer(recv_sizes).getId();
this->getTimers()->start(DeviceTransfers);
copyArrayToHost(nsend_id, ReadOnly);
array_flags->setHostFlag(nrecv_id);
array_flags->clearDeviceFlag(nrecv_id);
this->getTimers()->stop(DeviceTransfers);
this->getTimers()->start(Communication);
this->getDomainPartitioner()->communicateSizes(dim, send_sizes, recv_sizes);
this->getTimers()->stop(Communication);
}
void PairsSimulation::communicateData(
......@@ -319,6 +324,7 @@ void PairsSimulation::communicateData(
auto nsend_id = getArrayByHostPointer(nsend).getId();
auto nrecv_id = getArrayByHostPointer(nrecv).getId();
this->getTimers()->start(DeviceTransfers);
copyArrayToHost(send_offsets_id, ReadOnly);
copyArrayToHost(recv_offsets_id, ReadOnly);
copyArrayToHost(nsend_id, ReadOnly);
......@@ -334,13 +340,18 @@ void PairsSimulation::communicateData(
}
copyArrayToHost(send_buf_id, Ignore, nsend_all * elem_size * sizeof(real_t));
array_flags->setHostFlag(recv_buf_id);
array_flags->clearDeviceFlag(recv_buf_id);
this->getTimers()->stop(DeviceTransfers);
this->getTimers()->start(Communication);
this->getDomainPartitioner()->communicateData(
dim, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv);
this->getTimers()->stop(Communication);
this->getTimers()->start(DeviceTransfers);
copyArrayToDevice(recv_buf_id, Ignore, nrecv_all * elem_size * sizeof(real_t));
this->getTimers()->stop(DeviceTransfers);
}
void PairsSimulation::communicateAllData(
......@@ -355,6 +366,7 @@ void PairsSimulation::communicateAllData(
auto nsend_id = getArrayByHostPointer(nsend).getId();
auto nrecv_id = getArrayByHostPointer(nrecv).getId();
this->getTimers()->start(DeviceTransfers);
copyArrayToHost(send_offsets_id, ReadOnly);
copyArrayToHost(recv_offsets_id, ReadOnly);
copyArrayToHost(nsend_id, ReadOnly);
......@@ -370,13 +382,18 @@ void PairsSimulation::communicateAllData(
}
copyArrayToHost(send_buf_id, Ignore, nsend_all * elem_size * sizeof(real_t));
array_flags->setHostFlag(recv_buf_id);
array_flags->clearDeviceFlag(recv_buf_id);
this->getTimers()->stop(DeviceTransfers);
this->getTimers()->start(Communication);
this->getDomainPartitioner()->communicateAllData(
ndims, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv);
this->getTimers()->stop(Communication);
this->getTimers()->start(DeviceTransfers);
copyArrayToDevice(recv_buf_id, Ignore, nrecv_all * elem_size * sizeof(real_t));
this->getTimers()->stop(DeviceTransfers);
}
void PairsSimulation::fillCommunicationArrays(int *neighbor_ranks, int *pbc, real_t *subdom) {
......
......@@ -37,6 +37,13 @@ enum Actions {
Ignore = 5
};
/// Identifiers of the built-in runtime timers.
///
/// The code generator passes these numbers as literal timer IDs to
/// pairs::register_timer / start_timer / stop_timer, so the values must stay
/// in sync with the Python-side `Timers` class.
enum Timers {
    /// Wraps the whole time-step loop.
    All = 0,
    /// Wraps the domain-partitioner exchange calls (registered as "mpi").
    Communication = 1,
    /// Wraps host <-> device copies of arrays and variables.
    DeviceTransfers = 2,
    /// First ID used for per-module timers (timer ID = module_id + Offset).
    Offset = 3
};
enum DomainPartitioning {
DimRanges = 0,
BoxList,
......
class Timers:
    """Numeric IDs of the built-in timers shared with the C++ runtime.

    The values mirror the C++ ``enum Timers`` and must be kept in sync with it.
    """

    # Sentinel that does not correspond to any registered timer.
    Invalid = -1
    # Wraps the whole time-step loop.
    All = 0
    # Wraps the communication calls; registered under the label "mpi".
    Communication = 1
    # Wraps host <-> device copies; registered under the label "transfers".
    DeviceTransfers = 2
    # First ID used for per-module timers (timer ID = module_id + Offset).
    Offset = 3

    @staticmethod
    def name(timer):
        """Return the label under which a built-in timer is registered.

        Declared as a static method so that it can be called both as
        ``Timers.name(t)`` and on an instance.

        :param timer: numeric timer ID
        :return: "all", "mpi" or "transfers" for the built-in timers,
                 "invalid" for any other value
        """
        if timer == Timers.All:
            return "all"

        if timer == Timers.Communication:
            return "mpi"

        if timer == Timers.DeviceTransfers:
            return "transfers"

        return "invalid"
from pairs.ir.block import pairs_inline
from pairs.ir.functions import Call_Void
from pairs.ir.timers import Timers
from pairs.sim.lowerable import FinalLowerable
class RegisterTimers(FinalLowerable):
    """Emits the ``pairs::register_timer`` calls for all timers of a simulation.

    Built-in timers occupy the IDs ``0 .. Timers.Offset - 1``; every module
    other than 'main' gets the ID ``module_id + Timers.Offset``, which is the
    same ID the module instrumentation uses for its start/stop calls.
    """

    def __init__(self, sim):
        self.sim = sim

    @pairs_inline
    def lower(self):
        # Built-in timers ("all", "mpi", "transfers"), each registered exactly once.
        for t in range(Timers.Offset):
            Call_Void(self.sim, "pairs::register_timer", [t, Timers.name(t)])

        # Module timers are shifted by Timers.Offset so that they never
        # collide with the built-in IDs.
        for m in self.sim.module_list:
            if m.name != 'main':
                Call_Void(self.sim, "pairs::register_timer", [m.module_id + Timers.Offset, m.name])
class RegisterMarkers(FinalLowerable):
......
......@@ -3,6 +3,7 @@ from pairs.ir.block import Block
from pairs.ir.branches import Branch, Filter
from pairs.ir.functions import Call_Void
from pairs.ir.loops import For
from pairs.ir.timers import Timers
class Timestep:
......@@ -63,9 +64,9 @@ class Timestep:
_capture = self.sim._capture_statements
self.sim.capture_statements(False)
block = Block(self.sim, [Call_Void(self.sim, "pairs::start_timer", [0]),
block = Block(self.sim, [Call_Void(self.sim, "pairs::start_timer", [Timers.All]),
self.timestep_loop,
Call_Void(self.sim, "pairs::stop_timer", [0])])
Call_Void(self.sim, "pairs::stop_timer", [Timers.All])])
self.sim.capture_statements(_capture)
return block
......@@ -5,12 +5,14 @@ from pairs.ir.branches import Filter
from pairs.ir.cast import Cast
from pairs.ir.contexts import Contexts
from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef
from pairs.ir.functions import Call_Void
from pairs.ir.kernel import Kernel, KernelLaunch
from pairs.ir.lit import Lit
from pairs.ir.loops import For
from pairs.ir.module import ModuleCall
from pairs.ir.mutator import Mutator
from pairs.ir.scalars import ScalarOp
from pairs.ir.timers import Timers
from pairs.ir.types import Types
......@@ -34,6 +36,9 @@ class AddDeviceCopies(Mutator):
if isinstance(s, ModuleCall):
copy_context = Contexts.Device if s.module.run_on_device else Contexts.Host
clear_context = Contexts.Host if s.module.run_on_device else Contexts.Device
new_stmts += [
Call_Void(ast_node.sim, "pairs::start_timer", [Timers.DeviceTransfers])
]
for array, action in s.module.arrays().items():
new_stmts += [CopyArray(s.sim, array, copy_context, action)]
......@@ -52,16 +57,27 @@ class AddDeviceCopies(Mutator):
if action != Actions.ReadOnly and var.device_flag:
new_stmts += [CopyVar(s.sim, var, Contexts.Device, action)]
new_stmts += [
Call_Void(ast_node.sim, "pairs::stop_timer", [Timers.DeviceTransfers])
]
new_stmts.append(s)
if isinstance(s, ModuleCall):
if s.module.run_on_device:
new_stmts += [
Call_Void(ast_node.sim, "pairs::start_timer", [Timers.DeviceTransfers])
]
for var, action in s.module.variables().items():
if action != Actions.ReadOnly and var.device_flag:
new_stmts += [CopyVar(s.sim, var, Contexts.Host, action)]
if self.module_resizes[s.module]:
new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host, Actions.Ignore)]
new_stmts += [
Call_Void(ast_node.sim, "pairs::stop_timer", [Timers.DeviceTransfers])
]
ast_node.stmts = new_stmts
return ast_node
......
......@@ -2,6 +2,7 @@ from pairs.ir.block import Block
from pairs.ir.functions import Call_Void
from pairs.ir.module import ModuleCall
from pairs.ir.mutator import Mutator
from pairs.ir.timers import Timers
class AddModulesInstrumentation(Mutator):
......@@ -14,7 +15,7 @@ class AddModulesInstrumentation(Mutator):
if module.name == 'main':
return ast_node
timer_id = module.module_id + 1
timer_id = module.module_id + Timers.Offset
start_timer = Call_Void(ast_node.sim, "pairs::start_timer", [timer_id])
stop_timer = Call_Void(ast_node.sim, "pairs::stop_timer", [timer_id])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment