diff --git a/runtime/comm.hpp b/runtime/comm.hpp deleted file mode 100644 index 29c45a48ed4b99ecfeb722522c70b64a60f9f912..0000000000000000000000000000000000000000 --- a/runtime/comm.hpp +++ /dev/null @@ -1,31 +0,0 @@ -#include "pairs.hpp" - -#pragma once - -namespace pairs { - -template<int ndims> -void initDomain(PairsSimulation<ndims> *pairs, int *argc, char ***argv, real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax) { - pairs->initDomain(argc, argv, xmin, xmax, ymin, ymax, zmin, zmax); -} - -template<int ndims> -void communicateSizes(PairsSimulation<ndims> *pairs, int dim, const int *send_sizes, int *recv_sizes) { - pairs->getDomainPartitioner()->communicateSizes(dim, send_sizes, recv_sizes); -} - -template<int ndims> -void communicateData( - PairsSimulation<ndims> *pairs, int dim, int elem_size, - const real_t *send_buf, const int *send_offsets, const int *nsend, - real_t *recv_buf, const int *recv_offsets, const int *nrecv) { - - pairs->getDomainPartitioner()->communicateData(dim, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv); -} - -template<int ndims> -void fillCommunicationArrays(PairsSimulation<ndims> *pairs, int neighbor_ranks[], int pbc[], real_t subdom[]) { - pairs->getDomainPartitioner()->fillArrays(neighbor_ranks, pbc, subdom); -} - -} diff --git a/runtime/domain/regular_6d_stencil.hpp b/runtime/domain/regular_6d_stencil.hpp index e65058d00fb63d998a8a0c8555e0e80548368cad..6414e4cb226452221285d0c6ce528986e398f1ed 100644 --- a/runtime/domain/regular_6d_stencil.hpp +++ b/runtime/domain/regular_6d_stencil.hpp @@ -121,21 +121,26 @@ public: const real_t *send_buf, const int *send_offsets, const int *nsend, real_t *recv_buf, const int *recv_offsets, const int *nrecv) { + const real_t *send_prev = &send_buf[send_offsets[dim * 2 + 0] * elem_size]; + const real_t *send_next = &send_buf[send_offsets[dim * 2 + 1] * elem_size]; + real_t *recv_prev = &recv_buf[recv_offsets[dim * 2 + 0] * elem_size]; + real_t *recv_next = &recv_buf[recv_offsets[dim * 2 + 1] * elem_size]; + if(prev[dim] != rank) { - MPI_Send(&send_buf[send_offsets[dim * 2 + 0]], nsend[dim * 2 + 0] * elem_size, MPI_DOUBLE, prev[dim], 0, MPI_COMM_WORLD); - MPI_Recv(&recv_buf[recv_offsets[dim * 2 + 0]], nrecv[dim * 2 + 0] * elem_size, MPI_DOUBLE, next[dim], 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); + MPI_Send(send_prev, nsend[dim * 2 + 0] * elem_size, MPI_DOUBLE, prev[dim], 0, MPI_COMM_WORLD); + MPI_Recv(recv_prev, nrecv[dim * 2 + 0] * elem_size, MPI_DOUBLE, next[dim], 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); } else { for(int i = 0; i < nsend[dim * 2 + 0] * elem_size; i++) { - recv_buf[recv_offsets[dim * 2 + 0] + i] = send_buf[send_offsets[dim * 2 + 0] + i]; + recv_prev[i] = send_prev[i]; } } if(next[dim] != rank) { - MPI_Send(&send_buf[send_offsets[dim * 2 + 1]], nsend[dim * 2 + 1] * elem_size, MPI_DOUBLE, next[dim], 0, MPI_COMM_WORLD); - MPI_Recv(&recv_buf[recv_offsets[dim * 2 + 1]], nrecv[dim * 2 + 1] * elem_size, MPI_DOUBLE, prev[dim], 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); + MPI_Send(send_next, nsend[dim * 2 + 1] * elem_size, MPI_DOUBLE, next[dim], 0, MPI_COMM_WORLD); + MPI_Recv(recv_next, nrecv[dim * 2 + 1] * elem_size, MPI_DOUBLE, prev[dim], 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); } else { for(int i = 0; i < nsend[dim * 2 + 1] * elem_size; i++) { - recv_buf[recv_offsets[dim * 2 + 1] + i] = send_buf[send_offsets[dim * 2 + 1] + i]; + recv_next[i] = send_next[i]; } } } diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp index 91ca0ae690ee333d484c06fc180b9c376865f640..690e707da0bc59e4fa538443f38a2e0c2ffea622 100644 --- a/runtime/pairs.hpp +++ b/runtime/pairs.hpp @@ -25,7 +25,7 @@ # define PAIRS_EXCEPTION(a) #endif -#define PAIRS_ERROR(...) fprintf(stderr, __VA_ARGS__) +#define PAIRS_ERROR(...) fprintf(stderr, __VA_ARGS__) namespace pairs { @@ -515,6 +515,23 @@ public: pairs::copy_to_host(prop.getDevicePointer(), prop.getHostPointer(), prop.getTotalSize()); } } + + void communicateSizes(int dim, const int *send_sizes, int *recv_sizes) { + this->getDomainPartitioner()->communicateSizes(dim, send_sizes, recv_sizes); + PAIRS_DEBUG("send_sizes=[%d, %d], recv_sizes=[%d, %d]\n", send_sizes[dim * 2 + 0], recv_sizes[dim * 2 + 1], recv_sizes[dim * 2 + 0], recv_sizes[dim * 2 + 1]); + } + + void communicateData( + int dim, int elem_size, + const real_t *send_buf, const int *send_offsets, const int *nsend, + real_t *recv_buf, const int *recv_offsets, const int *nrecv) { + + this->getDomainPartitioner()->communicateData(dim, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv); + } + + void fillCommunicationArrays(int neighbor_ranks[], int pbc[], real_t subdom[]) { + this->getDomainPartitioner()->fillArrays(neighbor_ranks, pbc, subdom); + } }; } diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 34e6709417e6c2426e5b44de8dc5ca77201fbf3c..bd41f7ac9a5fbd20fedb132f5fe67d8d2fbfee07 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -55,7 +55,6 @@ class CGen: self.print("#include <stdlib.h>") self.print("//---") self.print("#include \"runtime/pairs.hpp\"") - self.print("#include \"runtime/comm.hpp\"") self.print("#include \"runtime/read_from_file.hpp\"") self.print("#include \"runtime/vtk.hpp\"") @@ -531,8 +530,15 @@ class CGen: return f"e{ast_node.id()}" if isinstance(ast_node, Call): - extra_params = [] if ast_node.name() != "pairs::initDomain" else ["&argc", "&argv"] - params = ", ".join(["pairs"] + extra_params + [str(self.generate_expression(p)) for p in ast_node.parameters()]) + extra_params = [] + + if ast_node.name().startswith("pairs::"): + extra_params += ["pairs"] + + if ast_node.name() == "pairs->initDomain": + extra_params += ["&argc", "&argv"] + + params = ", ".join(extra_params + [str(self.generate_expression(p)) for p in ast_node.parameters()]) return f"{ast_node.name()}({params})" if isinstance(ast_node, Cast): diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py index 4b2931735a6662e0d1b44bcbc2211253dd359d41..b95000c81e6e3d9576e4c66b8fe75b5b3f0d30e3 100644 --- a/src/pairs/sim/comm.py +++ b/src/pairs/sim/comm.py @@ -79,7 +79,7 @@ class CommunicateSizes(Lowerable): @pairs_inline def lower(self): - Call_Void(self.sim, "pairs::communicateSizes", [self.step, self.comm.nsend, self.comm.nrecv]) + Call_Void(self.sim, "pairs->communicateSizes", [self.step, self.comm.nsend, self.comm.nrecv]) class CommunicateData(Lowerable): @@ -93,7 +93,7 @@ class CommunicateData(Lowerable): @pairs_inline def lower(self): elem_size = sum([self.sim.ndims() if p.type() == Types.Vector else 1 for p in self.prop_list]) - Call_Void(self.sim, "pairs::communicateData", [self.step, elem_size, + Call_Void(self.sim, "pairs->communicateData", [self.step, elem_size, self.comm.send_buffer, self.comm.send_offsets, self.comm.nsend, self.comm.recv_buffer, self.comm.recv_offsets, self.comm.nrecv]) diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index dad95eddf3e8149776b841198d4a0e13ea20590f..2e73b7bbd0fb2ba041114f3ac1d10ea21d87ca12 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -245,8 +245,8 @@ class Simulation: self.capture_statements(False) grid_array = [[self.grid.min(d), self.grid.max(d)] for d in range(self.ndims())] self.setups.add_statement([ - Call_Void(self, "pairs::initDomain", [param for delim in grid_array for param in delim]), - Call_Void(self, "pairs::fillCommunicationArrays", [dom_part.neighbor_ranks, dom_part.pbc, dom_part.subdom]) + Call_Void(self, "pairs->initDomain", [param for delim in grid_array for param in delim]), + Call_Void(self, "pairs->fillCommunicationArrays", [dom_part.neighbor_ranks, dom_part.pbc, dom_part.subdom]) ]) self.capture_statements() # TODO: check if this is actually required