From a301e91ad8c91707e2ec6886aef7e01569589894 Mon Sep 17 00:00:00 2001
From: Behzad Safaei <iwia103h@alex1.nhr.fau.de>
Date: Sun, 24 Nov 2024 18:21:46 +0100
Subject: [PATCH] Fix FindwaLBerla issue, modularize communication

---
 CMakeLists.txt              |  2 +-
 FindwaLBerla.cmake          |  2 +-
 examples/main.cpp           |  2 ++
 src/pairs/code_gen/cgen.py  | 12 ++++++++++--
 src/pairs/sim/simulation.py | 17 +++++++++++++----
 5 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b7fd6a6..3f0b768 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -29,8 +29,8 @@ option(GENERATE_WHOLE_PROGRAM "GENERATE_WHOLE_PROGRAM" OFF)
 
 if(USE_WALBERLA)
     SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_WALBERLA_LOAD_BALANCING -DUSE_WALBERLA")
+    set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}")
     find_package(waLBerla REQUIRED)
-    add_subdirectory(${walberla_SOURCE_DIR} ${walberla_BINARY_DIR} EXCLUDE_FROM_ALL)
     waLBerla_import()
 endif()
 
diff --git a/FindwaLBerla.cmake b/FindwaLBerla.cmake
index 1bbb2d2..3eb44c5 100644
--- a/FindwaLBerla.cmake
+++ b/FindwaLBerla.cmake
@@ -3,7 +3,7 @@ set( WALBERLA_DIR    WALBERLA_DIR-NOTFOUND   CACHE  PATH  "waLBerla path"  )
 if ( WALBERLA_DIR )
     # WALBERLA_DIR has to point to the waLBerla source directory
     # this command builds waLBerla (again) in the current build directory in the subfolder "walberla" (second argument)
-    add_subdirectory( ${WALBERLA_DIR} walberla  )
+    add_subdirectory( ${WALBERLA_DIR} walberla EXCLUDE_FROM_ALL)
     
     waLBerla_import()
     # Adds the 'src' and 'tests' directory of current app
diff --git a/examples/main.cpp b/examples/main.cpp
index 5f8473b..ea9094f 100644
--- a/examples/main.cpp
+++ b/examples/main.cpp
@@ -79,6 +79,8 @@ int main(int argc, char **argv) {
         //     int idx = pairs_acc->uidToIdx(pUid);
         //     std::cout<< "Tracked particle is now in rank " << rank << " --- " << pairs_acc->getPosition(idx)<< std::endl;
         // }
+
+        pairs_sim->communicate(t);
         
         if (pIsLocalInMyRank(pUid)){
             int idx = pairs_acc->uidToIdx(pUid);
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 9eef2d3..800d047 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -366,7 +366,7 @@ class CGen:
 
         self.print.end()
 
-    def generate_library(self, initialize_module, create_domain_module, setup_sim_module,  do_timestep_module, reverse_comm_module):
+    def generate_library(self, initialize_module, create_domain_module, setup_sim_module,  do_timestep_module, reverse_comm_module, communicate_module):
         self.generate_interfaces()
         # Generate CUDA/CPP file with modules
         ext = ".cu" if self.target.is_gpu() else ".cpp"
@@ -394,7 +394,7 @@ class CGen:
             self.generate_kernel(kernel)
 
         for module in self.sim.modules():
-            if module.name not in ['initialize', 'create_domain', 'setup_sim', 'do_timestep', 'reverse_comm']:
+            if module.name not in ['initialize', 'create_domain', 'setup_sim', 'do_timestep', 'reverse_comm', 'communicate']:
                 self.generate_module(module)
 
         self.print.end()
@@ -483,6 +483,14 @@ class CGen:
         self.print("}")
         self.print("")
 
+        self.print("void communicate(int timestep) {")
+        self.print("    pobj->sim_timestep = timestep;")
+        self.print.add_indent(4)
+        self.generate_statement(communicate_module.block)
+        self.print.add_indent(-4)
+        self.print("}")
+        self.print("")
+
         self.print("void end() {")
         self.print("    pairs::print_timers(pairs_runtime);")
         self.print("    pairs::print_stats(pairs_runtime, pobj->nlocal, pobj->nghost);")
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index c23c377..acab2a5 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -448,9 +448,17 @@ class Simulation:
         every_reneighbor_params = {'every': self.reneighbor_frequency}
 
         # First steps executed during each time-step in the simulation
-        timestep_procedures = self.pre_step_functions + [
+        timestep_procedures = self.pre_step_functions 
+
+        comm_routine = [
             (comm.exchange(), every_reneighbor_params),
-            (comm.borders(), comm.synchronize(), every_reneighbor_params),
+            (comm.borders(), comm.synchronize(), every_reneighbor_params)
+            ]
+        
+        if self._generate_whole_program:
+            timestep_procedures += comm_routine
+
+        timestep_procedures +=    [
             (BuildCellLists(self, self.cell_lists), every_reneighbor_params),
             (PartitionCellLists(self, self.cell_lists), every_reneighbor_params)
         ]
@@ -551,13 +559,14 @@ class Simulation:
 
             setup_sim_module = Module(self, name='setup_sim', block=setup_sim)
             do_timestep_module = Module(self, name='do_timestep', block=timestep.as_block())
+            communicate_module = Module(self, name='communicate', block=Timestep(self, 0, comm_routine).as_block())
 
-            modules_list = [initialize_module, create_domain_module, setup_sim_module, do_timestep_module, reverse_comm_module]
+            modules_list = [initialize_module, create_domain_module, setup_sim_module, do_timestep_module, reverse_comm_module, communicate_module]
 
             transformations = Transformations(modules_list, self._target)
             transformations.apply_all()
 
             # Generate library
-            self.code_gen.generate_library(initialize_module, create_domain_module, setup_sim_module, do_timestep_module, reverse_comm_module)
+            self.code_gen.generate_library(initialize_module, create_domain_module, setup_sim_module, do_timestep_module, reverse_comm_module, communicate_module)
 
         self.code_gen.generate_interfaces()
-- 
GitLab