diff --git a/runtime/pairs.cpp b/runtime/pairs.cpp
index 2fcb2d60cadeffe7f722a75e7a86474ef9a759a9..09d8cfd2bb64565d313f2615445d7dbdc2024df8 100644
--- a/runtime/pairs.cpp
+++ b/runtime/pairs.cpp
@@ -40,6 +40,15 @@ Array &PairsSimulation::getArrayByName(std::string name) {
     return *a;
 }
 
+// Returns the registered array whose host pointer equals h_ptr.
+// PAIRS_ASSERT fires when no registered array matches.
+Array &PairsSimulation::getArrayByHostPointer(const void *h_ptr) {
+    // Bind each element by reference: a by-value parameter copies every Array inspected.
+    auto a = std::find_if(arrays.begin(), arrays.end(), [h_ptr](Array &arr) { return arr.getHostPointer() == h_ptr; });
+    PAIRS_ASSERT(a != std::end(arrays));
+    return *a;
+}
+
 void PairsSimulation::addProperty(Property prop) {
     int id = prop.getId();
     auto p = std::find_if(properties.begin(), properties.end(), [id](Property p) { return p.getId() == id; });
@@ -100,6 +109,17 @@ void PairsSimulation::copyPropertyToHost(Property &prop) {
 }
 
 void PairsSimulation::communicateSizes(int dim, const int *send_sizes, int *recv_sizes) {
+    // Map the host pointers passed by the caller back to their registered array ids.
+    auto nsend_id = getArrayByHostPointer(send_sizes).getId();
+    auto nrecv_id = getArrayByHostPointer(recv_sizes).getId();
+
+    // Copy the send sizes to the host before delegating to the domain partitioner, then
+    // set the host flag and clear the device flag on both size arrays.
+    copyArrayToHost(nsend_id);
+    array_flags->setHostFlag(nsend_id);
+    array_flags->clearDeviceFlag(nsend_id);
+    array_flags->setHostFlag(nrecv_id);
+    array_flags->clearDeviceFlag(nrecv_id);
     this->getDomainPartitioner()->communicateSizes(dim, send_sizes, recv_sizes);
     PAIRS_DEBUG("send_sizes=[%d, %d], recv_sizes=[%d, %d]\n", send_sizes[dim * 2 + 0], send_sizes[dim * 2 + 1], recv_sizes[dim * 2 + 0], recv_sizes[dim * 2 + 1]);
 }
@@ -109,6 +129,10 @@ void PairsSimulation::communicateData(
     const real_t *send_buf, const int *send_offsets, const int *nsend,
     real_t *recv_buf, const int *recv_offsets, const int *nrecv) {
 
+    // Set the host flag and clear the device flag on the receive buffer before delegating.
+    auto recv_buf_id = getArrayByHostPointer(recv_buf).getId();
+    array_flags->setHostFlag(recv_buf_id);
+    array_flags->clearDeviceFlag(recv_buf_id);
     this->getDomainPartitioner()->communicateData(dim, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv);
 
 /*
diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp
index f96ec88dbf21eaede0ca92e1c33582c31da79262..7130d6c4e8df9cf55db35dfe5271c27a25c5f291 100644
--- a/runtime/pairs.hpp
+++ b/runtime/pairs.hpp
@@ -57,6 +57,7 @@ public:
 
     Array &getArray(array_t id);
     Array &getArrayByName(std::string name);
+    Array &getArrayByHostPointer(const void *h_ptr);
 
     template<typename T_ptr> void addProperty(
         property_t id, std::string name, T_ptr **h_ptr, std::nullptr_t, PropertyType type, layout_t layout, size_t sx, size_t sy = 1);
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
index 548326132ecdc67ec7b83093c01d8c9b9d72c946..51f026fd6221a09143e7cc3b31026ab0761cbaa7 100644
--- a/src/pairs/ir/module.py
+++ b/src/pairs/ir/module.py
@@ -65,13 +65,15 @@ class Module(ASTNode):
         return self._host_references
 
     def properties_to_synchronize(self):
-        return {p for p in self._properties if self._properties[p][0] == 'r'}
+        #return {p for p in self._properties if self._properties[p][0] == 'r'}
+        return {p for p in self._properties}
 
     def write_properties(self):
         return {p for p in self._properties if 'w' in self._properties[p]}
 
     def arrays_to_synchronize(self):
-        return {a for a in self._arrays if a.sync() and self._arrays[a][0] == 'r'}
+        #return {a for a in self._arrays if a.sync() and self._arrays[a][0] == 'r'}
+        return {a for a in self._arrays if a.sync()}
 
     def write_arrays(self):
         return {a for a in self._arrays if a.sync() and 'w' in self._arrays[a]}
diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index b2dd3c259968f282520d8c2b128be882aaa9aef5..87ee4c591e47efc269fd9cc7d7ab313216c3d909 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -115,6 +115,7 @@ class DetermineGhostParticles(Lowerable):
         self.spacing = spacing
         self.sim.add_statement(self)
 
+    #@pairs_host_block
     @pairs_device_block
     def lower(self):
         nsend_all = self.comm.nsend_all
@@ -189,6 +190,7 @@ class PackGhostParticles(Lowerable):
     def get_elems_per_particle(self):
         return sum([self.sim.ndims() if p.type() == Types.Vector else 1 for p in self.prop_list])
 
+    #@pairs_host_block
     @pairs_device_block
     def lower(self):
         send_buffer = self.comm.send_buffer
@@ -230,6 +232,7 @@ class UnpackGhostParticles(Lowerable):
     def get_elems_per_particle(self):
         return sum([self.sim.ndims() if p.type() == Types.Vector else 1 for p in self.prop_list])
 
+    #@pairs_host_block
     @pairs_device_block
     def lower(self):
         nlocal = self.sim.nlocal
@@ -285,6 +288,7 @@ class RemoveExchangedParticles_part2(Lowerable):
         self.prop_list = prop_list
         self.sim.add_statement(self)
 
+    #@pairs_host_block
     @pairs_device_block
     def lower(self):
         self.sim.module_name("remove_exchanged_particles_pt2")