Skip to content
Snippets Groups Projects
Commit 169860f3 authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Adjust timers to comprise MPI and device transfers

parent c9f3c097
No related branches found
No related tags found
No related merge requests found
...@@ -301,10 +301,15 @@ void PairsSimulation::communicateSizes(int dim, const int *send_sizes, int *recv ...@@ -301,10 +301,15 @@ void PairsSimulation::communicateSizes(int dim, const int *send_sizes, int *recv
auto nsend_id = getArrayByHostPointer(send_sizes).getId(); auto nsend_id = getArrayByHostPointer(send_sizes).getId();
auto nrecv_id = getArrayByHostPointer(recv_sizes).getId(); auto nrecv_id = getArrayByHostPointer(recv_sizes).getId();
this->getTimers()->start(DeviceTransfers);
copyArrayToHost(nsend_id, ReadOnly); copyArrayToHost(nsend_id, ReadOnly);
array_flags->setHostFlag(nrecv_id); array_flags->setHostFlag(nrecv_id);
array_flags->clearDeviceFlag(nrecv_id); array_flags->clearDeviceFlag(nrecv_id);
this->getTimers()->stop(DeviceTransfers);
this->getTimers()->start(Communication);
this->getDomainPartitioner()->communicateSizes(dim, send_sizes, recv_sizes); this->getDomainPartitioner()->communicateSizes(dim, send_sizes, recv_sizes);
this->getTimers()->stop(Communication);
} }
void PairsSimulation::communicateData( void PairsSimulation::communicateData(
...@@ -319,6 +324,7 @@ void PairsSimulation::communicateData( ...@@ -319,6 +324,7 @@ void PairsSimulation::communicateData(
auto nsend_id = getArrayByHostPointer(nsend).getId(); auto nsend_id = getArrayByHostPointer(nsend).getId();
auto nrecv_id = getArrayByHostPointer(nrecv).getId(); auto nrecv_id = getArrayByHostPointer(nrecv).getId();
this->getTimers()->start(DeviceTransfers);
copyArrayToHost(send_offsets_id, ReadOnly); copyArrayToHost(send_offsets_id, ReadOnly);
copyArrayToHost(recv_offsets_id, ReadOnly); copyArrayToHost(recv_offsets_id, ReadOnly);
copyArrayToHost(nsend_id, ReadOnly); copyArrayToHost(nsend_id, ReadOnly);
...@@ -334,13 +340,18 @@ void PairsSimulation::communicateData( ...@@ -334,13 +340,18 @@ void PairsSimulation::communicateData(
} }
copyArrayToHost(send_buf_id, Ignore, nsend_all * elem_size * sizeof(real_t)); copyArrayToHost(send_buf_id, Ignore, nsend_all * elem_size * sizeof(real_t));
array_flags->setHostFlag(recv_buf_id); array_flags->setHostFlag(recv_buf_id);
array_flags->clearDeviceFlag(recv_buf_id); array_flags->clearDeviceFlag(recv_buf_id);
this->getTimers()->stop(DeviceTransfers);
this->getTimers()->start(Communication);
this->getDomainPartitioner()->communicateData( this->getDomainPartitioner()->communicateData(
dim, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv); dim, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv);
this->getTimers()->stop(Communication);
this->getTimers()->start(DeviceTransfers);
copyArrayToDevice(recv_buf_id, Ignore, nrecv_all * elem_size * sizeof(real_t)); copyArrayToDevice(recv_buf_id, Ignore, nrecv_all * elem_size * sizeof(real_t));
this->getTimers()->stop(DeviceTransfers);
} }
void PairsSimulation::communicateAllData( void PairsSimulation::communicateAllData(
...@@ -355,6 +366,7 @@ void PairsSimulation::communicateAllData( ...@@ -355,6 +366,7 @@ void PairsSimulation::communicateAllData(
auto nsend_id = getArrayByHostPointer(nsend).getId(); auto nsend_id = getArrayByHostPointer(nsend).getId();
auto nrecv_id = getArrayByHostPointer(nrecv).getId(); auto nrecv_id = getArrayByHostPointer(nrecv).getId();
this->getTimers()->start(DeviceTransfers);
copyArrayToHost(send_offsets_id, ReadOnly); copyArrayToHost(send_offsets_id, ReadOnly);
copyArrayToHost(recv_offsets_id, ReadOnly); copyArrayToHost(recv_offsets_id, ReadOnly);
copyArrayToHost(nsend_id, ReadOnly); copyArrayToHost(nsend_id, ReadOnly);
...@@ -370,13 +382,18 @@ void PairsSimulation::communicateAllData( ...@@ -370,13 +382,18 @@ void PairsSimulation::communicateAllData(
} }
copyArrayToHost(send_buf_id, Ignore, nsend_all * elem_size * sizeof(real_t)); copyArrayToHost(send_buf_id, Ignore, nsend_all * elem_size * sizeof(real_t));
array_flags->setHostFlag(recv_buf_id); array_flags->setHostFlag(recv_buf_id);
array_flags->clearDeviceFlag(recv_buf_id); array_flags->clearDeviceFlag(recv_buf_id);
this->getTimers()->stop(DeviceTransfers);
this->getTimers()->start(Communication);
this->getDomainPartitioner()->communicateAllData( this->getDomainPartitioner()->communicateAllData(
ndims, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv); ndims, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv);
this->getTimers()->stop(Communication);
this->getTimers()->start(DeviceTransfers);
copyArrayToDevice(recv_buf_id, Ignore, nrecv_all * elem_size * sizeof(real_t)); copyArrayToDevice(recv_buf_id, Ignore, nrecv_all * elem_size * sizeof(real_t));
this->getTimers()->stop(DeviceTransfers);
} }
void PairsSimulation::fillCommunicationArrays(int *neighbor_ranks, int *pbc, real_t *subdom) { void PairsSimulation::fillCommunicationArrays(int *neighbor_ranks, int *pbc, real_t *subdom) {
......
...@@ -37,6 +37,13 @@ enum Actions { ...@@ -37,6 +37,13 @@ enum Actions {
Ignore = 5 Ignore = 5
}; };
enum Timers {
All = 0,
Communication = 1,
DeviceTransfers = 2,
Offset = 3
};
enum DomainPartitioning { enum DomainPartitioning {
DimRanges = 0, DimRanges = 0,
BoxList, BoxList,
......
class Timers:
Invalid = -1
All = 0
Communication = 1
DeviceTransfers = 2
Offset = 3
def name(timer):
return "all" if timer == Timers.All else \
"mpi" if timer == Timers.Communication else \
"transfers" if timer == Timers.DeviceTransfers else \
"invalid"
from pairs.ir.block import pairs_inline from pairs.ir.block import pairs_inline
from pairs.ir.functions import Call_Void from pairs.ir.functions import Call_Void
from pairs.ir.timers import Timers
from pairs.sim.lowerable import FinalLowerable from pairs.sim.lowerable import FinalLowerable
class RegisterTimers(FinalLowerable): class RegisterTimers(FinalLowerable):
def __init__(self, sim): def __init__(self, sim):
self.sim = sim self.sim = sim
@pairs_inline @pairs_inline
def lower(self): def lower(self):
Call_Void(self.sim, "pairs::register_timer", [0, "all"]) for t in range(Timers.Offset):
Call_Void(self.sim, "pairs::register_timer", [t, Timers.name(t)])
for m in self.sim.module_list: for m in self.sim.module_list:
if m.name != 'main': if m.name != 'main':
Call_Void(self.sim, "pairs::register_timer", [m.module_id + 1, m.name]) Call_Void(self.sim, "pairs::register_timer", [m.module_id + Timers.Offset, m.name])
class RegisterMarkers(FinalLowerable): class RegisterMarkers(FinalLowerable):
......
...@@ -3,6 +3,7 @@ from pairs.ir.block import Block ...@@ -3,6 +3,7 @@ from pairs.ir.block import Block
from pairs.ir.branches import Branch, Filter from pairs.ir.branches import Branch, Filter
from pairs.ir.functions import Call_Void from pairs.ir.functions import Call_Void
from pairs.ir.loops import For from pairs.ir.loops import For
from pairs.ir.timers import Timers
class Timestep: class Timestep:
...@@ -63,9 +64,9 @@ class Timestep: ...@@ -63,9 +64,9 @@ class Timestep:
_capture = self.sim._capture_statements _capture = self.sim._capture_statements
self.sim.capture_statements(False) self.sim.capture_statements(False)
block = Block(self.sim, [Call_Void(self.sim, "pairs::start_timer", [0]), block = Block(self.sim, [Call_Void(self.sim, "pairs::start_timer", [Timers.All]),
self.timestep_loop, self.timestep_loop,
Call_Void(self.sim, "pairs::stop_timer", [0])]) Call_Void(self.sim, "pairs::stop_timer", [Timers.All])])
self.sim.capture_statements(_capture) self.sim.capture_statements(_capture)
return block return block
...@@ -5,12 +5,14 @@ from pairs.ir.branches import Filter ...@@ -5,12 +5,14 @@ from pairs.ir.branches import Filter
from pairs.ir.cast import Cast from pairs.ir.cast import Cast
from pairs.ir.contexts import Contexts from pairs.ir.contexts import Contexts
from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef
from pairs.ir.functions import Call_Void
from pairs.ir.kernel import Kernel, KernelLaunch from pairs.ir.kernel import Kernel, KernelLaunch
from pairs.ir.lit import Lit from pairs.ir.lit import Lit
from pairs.ir.loops import For from pairs.ir.loops import For
from pairs.ir.module import ModuleCall from pairs.ir.module import ModuleCall
from pairs.ir.mutator import Mutator from pairs.ir.mutator import Mutator
from pairs.ir.scalars import ScalarOp from pairs.ir.scalars import ScalarOp
from pairs.ir.timers import Timers
from pairs.ir.types import Types from pairs.ir.types import Types
...@@ -34,6 +36,9 @@ class AddDeviceCopies(Mutator): ...@@ -34,6 +36,9 @@ class AddDeviceCopies(Mutator):
if isinstance(s, ModuleCall): if isinstance(s, ModuleCall):
copy_context = Contexts.Device if s.module.run_on_device else Contexts.Host copy_context = Contexts.Device if s.module.run_on_device else Contexts.Host
clear_context = Contexts.Host if s.module.run_on_device else Contexts.Device clear_context = Contexts.Host if s.module.run_on_device else Contexts.Device
new_stmts += [
Call_Void(ast_node.sim, "pairs::start_timer", [Timers.DeviceTransfers])
]
for array, action in s.module.arrays().items(): for array, action in s.module.arrays().items():
new_stmts += [CopyArray(s.sim, array, copy_context, action)] new_stmts += [CopyArray(s.sim, array, copy_context, action)]
...@@ -52,16 +57,27 @@ class AddDeviceCopies(Mutator): ...@@ -52,16 +57,27 @@ class AddDeviceCopies(Mutator):
if action != Actions.ReadOnly and var.device_flag: if action != Actions.ReadOnly and var.device_flag:
new_stmts += [CopyVar(s.sim, var, Contexts.Device, action)] new_stmts += [CopyVar(s.sim, var, Contexts.Device, action)]
new_stmts += [
Call_Void(ast_node.sim, "pairs::stop_timer", [Timers.DeviceTransfers])
]
new_stmts.append(s) new_stmts.append(s)
if isinstance(s, ModuleCall): if isinstance(s, ModuleCall):
if s.module.run_on_device: if s.module.run_on_device:
new_stmts += [
Call_Void(ast_node.sim, "pairs::start_timer", [Timers.DeviceTransfers])
]
for var, action in s.module.variables().items(): for var, action in s.module.variables().items():
if action != Actions.ReadOnly and var.device_flag: if action != Actions.ReadOnly and var.device_flag:
new_stmts += [CopyVar(s.sim, var, Contexts.Host, action)] new_stmts += [CopyVar(s.sim, var, Contexts.Host, action)]
if self.module_resizes[s.module]: if self.module_resizes[s.module]:
new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host, Actions.Ignore)] new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host, Actions.Ignore)]
new_stmts += [
Call_Void(ast_node.sim, "pairs::stop_timer", [Timers.DeviceTransfers])
]
ast_node.stmts = new_stmts ast_node.stmts = new_stmts
return ast_node return ast_node
......
...@@ -2,6 +2,7 @@ from pairs.ir.block import Block ...@@ -2,6 +2,7 @@ from pairs.ir.block import Block
from pairs.ir.functions import Call_Void from pairs.ir.functions import Call_Void
from pairs.ir.module import ModuleCall from pairs.ir.module import ModuleCall
from pairs.ir.mutator import Mutator from pairs.ir.mutator import Mutator
from pairs.ir.timers import Timers
class AddModulesInstrumentation(Mutator): class AddModulesInstrumentation(Mutator):
...@@ -14,7 +15,7 @@ class AddModulesInstrumentation(Mutator): ...@@ -14,7 +15,7 @@ class AddModulesInstrumentation(Mutator):
if module.name == 'main': if module.name == 'main':
return ast_node return ast_node
timer_id = module.module_id + 1 timer_id = module.module_id + Timers.Offset
start_timer = Call_Void(ast_node.sim, "pairs::start_timer", [timer_id]) start_timer = Call_Void(ast_node.sim, "pairs::start_timer", [timer_id])
stop_timer = Call_Void(ast_node.sim, "pairs::stop_timer", [timer_id]) stop_timer = Call_Void(ast_node.sim, "pairs::stop_timer", [timer_id])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment