From c2e201c37fa8e9767c3895f19f5101faa51b1b1c Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Tue, 28 Feb 2023 18:17:33 +0100
Subject: [PATCH] fixes

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 runtime/pairs.cpp     | 13 +++++++++++--
 src/pairs/sim/comm.py |  5 +++--
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/runtime/pairs.cpp b/runtime/pairs.cpp
index 09d8cfd..98d7418 100644
--- a/runtime/pairs.cpp
+++ b/runtime/pairs.cpp
@@ -110,8 +110,6 @@ void PairsSimulation::communicateSizes(int dim, const int *send_sizes, int *recv
     auto nrecv_id = getArrayByHostPointer(recv_sizes).getId();
 
     copyArrayToHost(nsend_id);
-    array_flags->setHostFlag(nsend_id);
-    array_flags->clearDeviceFlag(nsend_id);
     array_flags->setHostFlag(nrecv_id);
     array_flags->clearDeviceFlag(nrecv_id);
     this->getDomainPartitioner()->communicateSizes(dim, send_sizes, recv_sizes);
@@ -123,7 +121,18 @@ void PairsSimulation::communicateData(
     const real_t *send_buf, const int *send_offsets, const int *nsend,
     real_t *recv_buf, const int *recv_offsets, const int *nrecv) {
 
+    auto send_buf_id = getArrayByHostPointer(send_buf).getId();
     auto recv_buf_id = getArrayByHostPointer(recv_buf).getId();
+    auto send_offsets_id = getArrayByHostPointer(send_offsets).getId();
+    auto recv_offsets_id = getArrayByHostPointer(recv_offsets).getId();
+    auto nsend_id = getArrayByHostPointer(nsend).getId();
+    auto nrecv_id = getArrayByHostPointer(nrecv).getId();
+
+    copyArrayToHost(send_buf_id);
+    copyArrayToHost(send_offsets_id);
+    copyArrayToHost(recv_offsets_id);
+    copyArrayToHost(nsend_id);
+    copyArrayToHost(nrecv_id);
     array_flags->setHostFlag(recv_buf_id);
     array_flags->clearDeviceFlag(recv_buf_id);
     this->getDomainPartitioner()->communicateData(dim, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv);
diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index 87ee4c5..76d67e7 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -115,8 +115,8 @@ class DetermineGhostParticles(Lowerable):
         self.spacing = spacing
         self.sim.add_statement(self)
 
-    #@pairs_host_block
-    @pairs_device_block
+    @pairs_host_block
+    #@pairs_device_block
     def lower(self):
         nsend_all = self.comm.nsend_all
         nsend = self.comm.nsend
@@ -128,6 +128,7 @@ class DetermineGhostParticles(Lowerable):
         ghost_or_exchg = "exchange" if is_exchange else "ghost"
         self.sim.module_name(f"determine_{ghost_or_exchg}_particles{self.step}")
         self.sim.check_resize(self.comm.send_capacity, nsend)
+        #self.sim.check_resize(self.comm.send_capacity, nsend_all)
 
         for j in self.comm.dom_part.step_indexes(self.step):
             nsend[j].set(0)
-- 
GitLab