From 66c41fd81ad9a9ff00d875cc16addf03cfe8db2a Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Sat, 3 Sep 2022 02:57:58 +0200
Subject: [PATCH] Still need to fix some of the runtime comm code

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 runtime/comm.hpp                       |  4 ++--
 runtime/domain/domain_partitioning.hpp | 14 ++++++++------
 runtime/domain/regular_6d_stencil.hpp  | 21 ++++++++++++---------
 runtime/pairs.hpp                      | 11 ++++++-----
 src/pairs/code_gen/cgen.py             |  5 +++--
 5 files changed, 31 insertions(+), 24 deletions(-)

diff --git a/runtime/comm.hpp b/runtime/comm.hpp
index d9a7abb..29c45a4 100644
--- a/runtime/comm.hpp
+++ b/runtime/comm.hpp
@@ -5,8 +5,8 @@
 namespace pairs {
 
 template<int ndims>
-void initDomain(PairsSimulation<ndims> *pairs, real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax) {
-    pairs->initDomain(xmin, xmax, ymin, ymax, zmin, zmax);
+void initDomain(PairsSimulation<ndims> *pairs, int *argc, char ***argv, real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax) {
+    pairs->initDomain(argc, argv, xmin, xmax, ymin, ymax, zmin, zmax);
 }
 
 template<int ndims>
diff --git a/runtime/domain/domain_partitioning.hpp b/runtime/domain/domain_partitioning.hpp
index 3936a50..9d22df2 100644
--- a/runtime/domain/domain_partitioning.hpp
+++ b/runtime/domain/domain_partitioning.hpp
@@ -40,7 +40,7 @@ public:
 };
 
 template<int ndims>
-class DimensionRanges : DomainPartitioner<ndims> {
+class DimensionRanges : public DomainPartitioner<ndims> {
 protected:
     int nranks[ndims];
     int prev[ndims];
@@ -68,14 +68,14 @@ public:
     void communicateSizes(int dim, const int *send_sizes, int *recv_sizes) {
         if(prev[dim] != this->getRank()) {
             MPI_Send(&send_sizes[dim * 2 + 0], 1, MPI_INT, prev[dim], 0, MPI_COMM_WORLD);
-            MPI_Recv(&recv_sizes[dim * 2 + 0], 1, MPI_INT, next[dim], 0, MPI_COMM_WORLD);
+            MPI_Recv(&recv_sizes[dim * 2 + 0], 1, MPI_INT, next[dim], 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
         } else {
             recv_sizes[dim * 2 + 0] = send_sizes[dim * 2 + 0];
         }
 
         if(next[dim] != this->getRank()) {
             MPI_Send(&send_sizes[dim * 2 + 1], 1, MPI_INT, next[dim], 0, MPI_COMM_WORLD);
-            MPI_Recv(&recv_sizes[dim * 2 + 1], 1, MPI_INT, prev[dim], 0, MPI_COMM_WORLD);
+            MPI_Recv(&recv_sizes[dim * 2 + 1], 1, MPI_INT, prev[dim], 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
         } else {
             recv_sizes[dim * 2 + 1] = send_sizes[dim * 2 + 1];
         }
@@ -88,7 +88,7 @@ public:
 
         if(prev[dim] != this->getRank()) {
             MPI_Send(&send_buf[send_offsets[dim * 2 + 0]], nsend[dim * 2 + 0] * elem_size, MPI_DOUBLE, prev[dim], 0, MPI_COMM_WORLD);
-            MPI_Recv(&recv_buf[recv_offsets[dim * 2 + 0]], nrecv[dim * 2 + 0] * elem_size, MPI_DOUBLE, next[dim], 0, MPI_COMM_WORLD);
+            MPI_Recv(&recv_buf[recv_offsets[dim * 2 + 0]], nrecv[dim * 2 + 0] * elem_size, MPI_DOUBLE, next[dim], 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
         } else {
             for(int i = 0; i < nsend[dim * 2 + 0] * elem_size; i++) {
                 recv_buf[recv_offsets[dim * 2 + 0] + i] = send_buf[send_offsets[dim * 2 + 0] + i];
@@ -97,7 +97,7 @@ public:
 
         if(next[dim] != this->getRank()) {
             MPI_Send(&send_buf[send_offsets[dim * 2 + 1]], nsend[dim * 2 + 1] * elem_size, MPI_DOUBLE, next[dim], 0, MPI_COMM_WORLD);
-            MPI_Recv(&recv_buf[recv_offsets[dim * 2 + 1]], nrecv[dim * 2 + 1] * elem_size, MPI_DOUBLE, prev[dim], 0, MPI_COMM_WORLD);
+            MPI_Recv(&recv_buf[recv_offsets[dim * 2 + 1]], nrecv[dim * 2 + 1] * elem_size, MPI_DOUBLE, prev[dim], 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
         } else {
             for(int i = 0; i < nsend[dim * 2 + 1] * elem_size; i++) {
                 recv_buf[recv_offsets[dim * 2 + 1] + i] = send_buf[send_offsets[dim * 2 + 1] + i];
@@ -107,6 +107,8 @@ public:
 };
 
 template<int ndims>
-class ListOfBoxes : DomainPartitioner<ndims> {};
+class ListOfBoxes : public DomainPartitioner<ndims> {};
+
+class DomainPartitioner<3>;
 
 }
diff --git a/runtime/domain/regular_6d_stencil.hpp b/runtime/domain/regular_6d_stencil.hpp
index 24d3423..6ec03ee 100644
--- a/runtime/domain/regular_6d_stencil.hpp
+++ b/runtime/domain/regular_6d_stencil.hpp
@@ -1,14 +1,17 @@
-//---
-#include "pairs.hpp"
 #include "domain_partitioning.hpp"
 
 #pragma once
 
+typedef double real_t;
+
 namespace pairs {
 
 template <int ndims>
-class Regular6DStencil : DimensionRanges<ndims> {
+class Regular6DStencil : public DimensionRanges<ndims> {
 public:
+    Regular6DStencil(real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax) :
+        DimensionRanges<ndims>(xmin, xmax, ymin, ymax, zmin, zmax) {}
+
     void setConfig() {
         static_assert(ndims == 3, "setConfig() only implemented for three dimensions!");
         real_t area[ndims];
@@ -20,9 +23,9 @@ public:
             best_surf += 2.0 * area[d];
         }
 
-        for(int i = 1; i < world_size; i++) {
-            if(world_size % i == 0) {
-                const int rem_yz = world_size / i;
+        for(int i = 1; i < this->world_size; i++) {
+            if(this->world_size % i == 0) {
+                const int rem_yz = this->world_size / i;
                 for(int j = 0; j < rem_yz; j++) {
                     if(rem_yz % j == 0) {
                         const int k = rem_yz / j;
@@ -31,7 +34,7 @@ public:
                             this->nranks[0] = i;
                             this->nranks[1] = j;
                             this->nranks[2] = k;
-                            bestsurf = surf;
+                            best_surf = surf;
                         }
                     }
                 }
@@ -61,10 +64,10 @@ public:
             this->subdom_max[d] = this->subdom_min[d] + rank_length[d];
         }
 
-        MPI_Comm_free(cartesian);
+        MPI_Comm_free(&cartesian);
     }
 
-    void initialize(int *argc, const char **argv) {
+    void initialize(int *argc, char ***argv) {
         MPI_Init(argc, argv);
         MPI_Comm_size(MPI_COMM_WORLD, &(this->world_size));
         MPI_Comm_rank(MPI_COMM_WORLD, &(this->rank));
diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp
index e0d00da..05aed09 100644
--- a/runtime/pairs.hpp
+++ b/runtime/pairs.hpp
@@ -10,7 +10,7 @@
 #   include "devices/dummy.hpp"
 #endif
 
-#include "domain/domain_partitioning.hpp"
+#include "domain/regular_6d_stencil.hpp"
 
 #pragma once
 
@@ -255,7 +255,8 @@ public:
 template<int ndims>
 class PairsSimulation {
 private:
-    DomainPartitioner<ndims> *dom_part;
+    Regular6DStencil<ndims> *dom_part;
+    //DomainPartitioner<ndims> *dom_part;
     std::vector<Property> properties;
     std::vector<Array> arrays;
     DeviceFlags *prop_flags, *array_flags;
@@ -274,14 +275,14 @@ public:
         delete array_flags;
     }
 
-    void initDomain(real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax) {
+    void initDomain(int *argc, char ***argv, real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax) {
         if(dom_part_type == DimRanges) {
-            dom_part = new DimensionRanges<ndims>(xmin, xmax, ymin, ymax, zmin, zmax);
+            dom_part = new Regular6DStencil<ndims>(xmin, xmax, ymin, ymax, zmin, zmax);
         } else {
             PAIRS_EXCEPTION("Domain partitioning type not implemented!\n");
         }
 
-        dom_part->initialize();
+        dom_part->initialize(argc, argv);
     }
 
     DomainPartitioner<ndims> *getDomainPartitioner() { return dom_part; }
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 68bab60..34e6709 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -91,7 +91,7 @@ class CGen:
             ndims = module.sim.ndims()
             nprops = module.sim.properties.nprops()
             narrays = module.sim.arrays.narrays()
-            self.print("int main() {")
+            self.print("int main(int argc, char **argv) {")
             self.print(f"    PairsSimulation<{ndims}> *pairs = new PairsSimulation<{ndims}>({nprops}, {narrays}, DimRanges);")
             self.generate_statement(module.block)
             self.print("    return 0;")
@@ -531,7 +531,8 @@ class CGen:
             return f"e{ast_node.id()}"
 
         if isinstance(ast_node, Call):
-            params = ", ".join(["pairs"] + [str(self.generate_expression(p)) for p in ast_node.parameters()])
+            extra_params = [] if ast_node.name() != "pairs::initDomain" else ["&argc", "&argv"]
+            params = ", ".join(["pairs"] + extra_params + [str(self.generate_expression(p)) for p in ast_node.parameters()])
             return f"{ast_node.name()}({params})"
 
         if isinstance(ast_node, Cast):
-- 
GitLab