From 61ee1fda8dc893c42edb11e2f29a10946a4b59fc Mon Sep 17 00:00:00 2001
From: behzad <safaei_b@hotmail.com>
Date: Fri, 20 Sep 2024 15:35:29 +0200
Subject: [PATCH] create_domain/use_domain, do_timestep

---
 examples/main.cpp               | 33 +++++++++++++++++++----
 runtime/domain/block_forest.hpp | 22 +++++++++++++++
 runtime/pairs.cpp               | 15 ++++++++++-
 runtime/pairs.hpp               | 29 ++++++++++++++++++++
 src/pairs/code_gen/cgen.py      | 47 ++++++++++++++++++++++++++-------
 src/pairs/sim/simulation.py     | 41 ++++++++++++++++------------
 src/pairs/sim/timestep.py       |  8 +++---
 7 files changed, 158 insertions(+), 37 deletions(-)

diff --git a/examples/main.cpp b/examples/main.cpp
index f3841e3..927cb50 100644
--- a/examples/main.cpp
+++ b/examples/main.cpp
@@ -4,11 +4,34 @@
 
 int main(int argc, char **argv) {
     PairsSimulation *ps = new PairsSimulation();
-    std::cout << "initialize" << std::endl;
-    ps->initialize(argc, argv);
-    std::cout << "do_timestep" << std::endl;
-    ps->do_timestep();
-    std::cout << "end" << std::endl;
+
+    // Create forest (make sure to use_domain(forest)) ----------------------------------------------
+    walberla::math::AABB domain(0, 0, 0, 0.1, 0.1, 0.1);
+    std::shared_ptr<walberla::mpi::MPIManager> mpiManager = walberla::mpi::MPIManager::instance();
+    mpiManager->initializeMPI(&argc, &argv);
+    mpiManager->useWorldComm();
+    auto procs = mpiManager->numProcesses();
+    auto block_config = walberla::Vector3<int>(2, 2, 1);
+    auto ref_level = 0;
+    std::shared_ptr<walberla::BlockForest> forest = walberla::blockforest::createBlockForest(
+            domain, block_config, walberla::Vector3<bool>(true, true, false), procs, ref_level);
+    //-----------------------------------------------------------------------------------------------
+
+    // initialize pairs data structures ----------------------------------------------
+    ps->initialize();
+
+    // either create new domain or use an existing one ----------------------------------------
+    // ps->create_domain(argc, argv);
+    ps->use_domain(forest);
+
+    // setup particles, setup functions, and the cell list stencil-------------------------------
+    ps->setup_sim();
+
+    for (int i=0; i<10000; ++i){
+        ps->do_timestep(i);
+    }
+
     ps->end();
+
     return 0;
 }
diff --git a/runtime/domain/block_forest.hpp b/runtime/domain/block_forest.hpp
index a92eb9b..c2ee68f 100644
--- a/runtime/domain/block_forest.hpp
+++ b/runtime/domain/block_forest.hpp
@@ -44,6 +44,28 @@ public:
         subdom = new real_t[ndims * 2];
     }
 
+    BlockForest(PairsRuntime *ps_, std::shared_ptr<walberla::BlockForest> bf) :
+        forest(bf),
+        DomainPartitioner(bf->getDomain().xMin(), bf->getDomain().xMax(),
+                        bf->getDomain().yMin(), bf->getDomain().yMax(),
+                        bf->getDomain().zMin(), bf->getDomain().zMax()), 
+        ps(ps_), 
+        globalPBC{bf->isXPeriodic(), bf->isYPeriodic(), bf->isZPeriodic()} 
+        {
+            subdom = new real_t[ndims * 2];
+            balance_workload = 0;
+
+            mpiManager = walberla::mpi::MPIManager::instance();
+            world_size = mpiManager->numProcesses();
+            rank = mpiManager->rank();
+            this->info = make_shared<walberla::blockforest::InfoCollection>();
+
+            if(balance_workload) {
+                this->initializeWorkloadBalancer();
+            }
+
+        }
+
     ~BlockForest() {
         delete[] subdom;
     }
diff --git a/runtime/pairs.cpp b/runtime/pairs.cpp
index 8127602..e89fff1 100644
--- a/runtime/pairs.cpp
+++ b/runtime/pairs.cpp
@@ -16,6 +16,18 @@ void PairsRuntime::initDomain(
     int *argc, char ***argv,
     real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax, bool pbcx, bool pbcy, bool pbcz) {
 
+    int mpi_initialized=0;
+    MPI_Initialized(&mpi_initialized);
+    
+    if(mpi_initialized){ 
+        PAIRS_ERROR("MPI is already initialized!\n"); 
+        exit(-1);
+    }
+    if(dom_part){ 
+        PAIRS_ERROR("DomainPartitioner already exists!\n"); 
+        exit(-1);
+    }
+
     if(dom_part_type == RegularPartitioning) {
         const int flags[] = {1, 1, 1};
         dom_part = new Regular6DStencil(xmin, xmax, ymin, ymax, zmin, zmax, flags);
@@ -25,7 +37,8 @@ void PairsRuntime::initDomain(
     } else if(dom_part_type == BlockForestPartitioning) {
         dom_part = new BlockForest(this, xmin, xmax, ymin, ymax, zmin, zmax, pbcx, pbcy, pbcz);
     } else {
-        PAIRS_EXCEPTION("Domain partitioning type not implemented!\n");
+        PAIRS_ERROR("Domain partitioning type not implemented!\n");
+        exit(-1);
     }
 
     dom_part->initialize(argc, argv);
diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp
index 2bcf71f..0c9f55e 100644
--- a/runtime/pairs.hpp
+++ b/runtime/pairs.hpp
@@ -46,6 +46,7 @@ public:
         int narrays_,
         DomainPartitioners dom_part_type_) {
 
+        dom_part = nullptr;
         dom_part_type = dom_part_type_;
         prop_flags = new DeviceFlags(nprops_);
         contact_prop_flags = new DeviceFlags(ncontactprops_);
@@ -304,6 +305,9 @@ public:
         real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax,
         bool pbcx = 0, bool pbcy = 0, bool pbcz = 0);
 
+    template<typename Domain_T>
+    void useDomain(std::shared_ptr<Domain_T> domain_ptr);
+
     void updateDomain() { dom_part->update(); }
 
     DomainPartitioner *getDomainPartitioner() { return dom_part; }
@@ -340,6 +344,31 @@ public:
     }
 };
 
+template<typename Domain_T>
+void PairsRuntime::useDomain(std::shared_ptr<Domain_T> domain_ptr){
+    
+    if(dom_part){ 
+        PAIRS_ERROR("DomainPartitioner already exists!\n"); 
+        exit(-1);
+    }
+
+    if(dom_part_type == RegularPartitioning) {
+        PAIRS_ERROR("useDomain not implemented for Regular6DStencil!\n");
+        exit(-1);
+
+    } else if(dom_part_type == RegularXYPartitioning) {        
+        PAIRS_ERROR("useDomain not implemented for Regular6DStencil!\n");
+        exit(-1);
+
+    } else if(dom_part_type == BlockForestPartitioning) {
+        dom_part = new BlockForest(this, domain_ptr);
+
+    } else {
+        PAIRS_ERROR("Domain partitioning type not implemented!\n");
+        exit(-1);
+    }
+}
+
 template<typename T_ptr>
 void PairsRuntime::addArray(array_t id, std::string name, T_ptr **h_ptr, std::nullptr_t, size_t size) {
     PAIRS_ASSERT(size > 0);
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index fec4fd8..1a3103a 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -170,13 +170,13 @@ class CGen:
 
         self.print("// Module headers")
         for module in self.sim.modules():
-            self.print(f"void {module.name}(struct pairs_objects *pobj);")
+            self.print(f"void {module.name}(struct PairsObjects *pobj);")
 
         self.print("")
 
     def generate_pairs_object_structure(self):
         self.print("")
-        self.print("struct pairs_objects {")
+        self.print("struct PairsObjects {")
         self.print.add_indent(4)
 
         self.print("// Arrays")
@@ -247,7 +247,7 @@ class CGen:
 
         self.print.end()
 
-    def generate_library(self, initialize_module, do_timestep_module):
+    def generate_library(self, initialize_module, create_domain_module, setup_sim_module,  do_timestep_module):
         # Generate CUDA/CPP file with modules
         ext = ".cu" if self.target.is_gpu() else ".cpp"
         self.print = Printer(self.ref + ext)
@@ -259,7 +259,7 @@ class CGen:
             self.generate_kernel(kernel)
 
         for module in self.sim.modules():
-            if module.name not in ['initialize', 'do_timestep']:
+            if module.name not in ['initialize', 'create_domain', 'setup_sim', 'do_timestep']:
                 self.generate_module(module)
 
         self.print.end()
@@ -282,22 +282,46 @@ class CGen:
         self.print("class PairsSimulation {")
         self.print("private:")
         self.print("    PairsRuntime *pairs_runtime;")
-        self.print("    struct pairs_objects *pobj;")
+        self.print("    struct PairsObjects *pobj;")
         self.print("public:")
-        self.print("    void initialize(int argc, char **argv) {")
+
+        self.print("    void initialize() {")
         self.print(f"        pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});")
-        self.print(f"        pobj = new pairs_objects();")
+        self.print(f"        pobj = new PairsObjects();")
         self.print.add_indent(4)
         self.generate_statement(initialize_module.block)
         self.print.add_indent(-4)
         self.print("    }")
         self.print("")
-        self.print("    void do_timestep() {")
+
+        self.print("    void create_domain(int argc, char **argv) {")
+        self.print.add_indent(4)
+        self.generate_statement(create_domain_module.block)
+        self.print.add_indent(-4)
+        self.print("    }")
+        self.print("")
+
+        self.print("    template<typename Domain_T>")
+        self.print("    void use_domain(std::shared_ptr<Domain_T> domain_ptr) {")
+        self.print("        pairs_runtime->useDomain(domain_ptr);")
+        self.print("    }")
+        self.print("")
+
+        self.print("    void setup_sim() {")
+        self.print.add_indent(4)
+        self.generate_statement(setup_sim_module.block)
+        self.print.add_indent(-4)
+        self.print("    }")
+        self.print("")
+
+        self.print("    void do_timestep(int timestep) {")
+        self.print("        pobj->sim_timestep = timestep;")
         self.print.add_indent(4)
         self.generate_statement(do_timestep_module.block)
         self.print.add_indent(-4)
         self.print("    }")
         self.print("")
+
         self.print("    void end() {")
         self.print("        pairs::print_timers(pairs_runtime);")
         self.print("        pairs::print_stats(pairs_runtime, pobj->nlocal, pobj->nghost);")
@@ -305,6 +329,7 @@ class CGen:
         self.print("        delete pairs_runtime;")
         self.print("    }")
         self.print("};")
+        
         self.print.end()
         self.generate_full_object_names = False
 
@@ -319,7 +344,7 @@ class CGen:
             self.generate_full_object_names = True
             self.print("int main(int argc, char **argv) {")
             self.print(f"    PairsRuntime *pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});")
-            self.print(f"    struct pairs_objects *pobj = new pairs_objects();")
+            self.print(f"    struct PairsObjects *pobj = new PairsObjects();")
 
             if module.sim._enable_profiler:
                 self.print("    LIKWID_MARKER_INIT;")
@@ -331,12 +356,14 @@ class CGen:
 
             self.print("    pairs::print_timers(pairs_runtime);")
             self.print("    pairs::print_stats(pairs_runtime, pobj->nlocal, pobj->nghost);")
+            self.print("    delete pobj;")
+            self.print("    delete pairs_runtime;")
             self.print("    return 0;")
             self.print("}")
             self.generate_full_object_names = False
 
         else:
-            self.print(f"void {module.name}(struct pairs_objects *pobj) {{")
+            self.print(f"void {module.name}(struct PairsObjects *pobj) {{")
             self.print.add_indent(4)
 
             if self.debug:
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index b0581e3..7d5a8e1 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -64,6 +64,7 @@ class Simulation:
         self.contact_properties = ContactProperties(self)
 
         # General capacities, sizes and particle properties
+        self.sim_timestep = self.add_var('sim_timestep', Types.Int32, runtime=True)
         self.particle_capacity = \
             self.add_var('particle_capacity', Types.Int32, particle_capacity, runtime=True)
         self.neighbor_capacity = self.add_var('neighbor_capacity', Types.Int32, neighbor_capacity)
@@ -90,7 +91,8 @@ class Simulation:
         self._block = Block(self, [])
 
         # Different segments of particle code/functions
-        self.setups = Block(self, [])
+        self.create_domain = Block(self, [])
+        self.setup_particles = Block(self, [])
         self.setup_functions = []
         self.pre_step_functions = []
         self.functions = []
@@ -262,26 +264,26 @@ class Simulation:
 
     def set_domain(self, grid):
         self.grid = Grid3D(self, grid[0], grid[1], grid[2], grid[3], grid[4], grid[5])
-        self.setups.add_statement(InitializeDomain(self))
+        self.create_domain.add_statement(InitializeDomain(self))
 
     def reneighbor_every(self, frequency):
         self.reneighbor_frequency = frequency
 
     def create_particle_lattice(self, grid, spacing, props={}):
-        self.setups.add_statement(ParticleLattice(self, grid, spacing, props, self.position()))
+        self.setup_particles.add_statement(ParticleLattice(self, grid, spacing, props, self.position()))
 
     def read_particle_data(self, filename, prop_names, shape_id):
         """Generate statement to read particle data from file"""
         props = [self.property(prop_name) for prop_name in prop_names]
-        self.setups.add_statement(ReadParticleData(self, filename, props, shape_id))
+        self.setup_particles.add_statement(ReadParticleData(self, filename, props, shape_id))
 
     def copper_fcc_lattice(self, nx, ny, nz, rho, temperature, ntypes):
         """Specific initialization for MD Copper FCC lattice case"""
-        self.setups.add_statement(CopperFCCLattice(self, nx, ny, nz, rho, temperature, ntypes))
+        self.setup_particles.add_statement(CopperFCCLattice(self, nx, ny, nz, rho, temperature, ntypes))
 
     def dem_sc_grid(self, xmax, ymax, zmax, spacing, diameter, min_diameter, max_diameter, initial_velocity, particle_density, ntypes):
         """Specific initialization for DEM grid"""
-        self.setups.add_statement(
+        self.setup_particles.add_statement(
             DEMSCGrid(self, xmax, ymax, zmax, spacing, diameter, min_diameter, max_diameter,
                       initial_velocity, particle_density, ntypes))
 
@@ -442,8 +444,6 @@ class Simulation:
         comm = Comm(self, self._dom_part)
         # Params that determine when a method must be called only when reneighboring
         every_reneighbor_params = {'every': self.reneighbor_frequency}
-        # Update domain is added at last on setups because particles must be already present in the simulation
-        self.setups.add_statement(UpdateDomain(self))
 
         # First steps executed during each time-step in the simulation
         timestep_procedures = self.pre_step_functions + [
@@ -509,8 +509,11 @@ class Simulation:
         # Combine everything into a whole program
         if self._generate_whole_program:
             # Initialization and setup functions, together with time-step loop
+            # UpdateDomain is added after setup_particles because particles must be already present in the simulation
             body = Block.from_list(self, [
-                self.setups,
+                self.create_domain,
+                self.setup_particles,
+                UpdateDomain(self),        
                 self.setup_functions,
                 BuildCellListsStencil(self, self.cell_lists),
                 timestep.as_block()
@@ -527,21 +530,25 @@ class Simulation:
 
         # Generate a small library to be called
         else:
-            all_setups = Block.merge_blocks(
-                inits,
-                Block.from_list(self, [
-                    self.setups,
+            initialize_module = Module(self, name='initialize', block=inits)
+            create_domain_module = Module(self, name='create_domain', block=self.create_domain)
+
+            setup_sim = Block.from_list(self, [
+                    self.setup_particles,
+                    UpdateDomain(self),
                     self.setup_functions,
                     BuildCellListsStencil(self, self.cell_lists),
-                ]))
+                ])
 
-            initialize_module = Module(self, name='initialize', block=all_setups)
+            setup_sim_module = Module(self, name='setup_sim', block=setup_sim)
             do_timestep_module = Module(self, name='do_timestep', block=timestep.as_block())
 
-            transformations = Transformations([initialize_module, do_timestep_module], self._target)
+            modules_list = [initialize_module, create_domain_module, setup_sim_module, do_timestep_module]
+
+            transformations = Transformations(modules_list, self._target)
             transformations.apply_all()
 
             # Generate library
-            self.code_gen.generate_library(initialize_module, do_timestep_module)
+            self.code_gen.generate_library(initialize_module, create_domain_module, setup_sim_module, do_timestep_module)
 
         self.code_gen.generate_interfaces()
diff --git a/src/pairs/sim/timestep.py b/src/pairs/sim/timestep.py
index 1281a4d..abef09a 100644
--- a/src/pairs/sim/timestep.py
+++ b/src/pairs/sim/timestep.py
@@ -10,7 +10,7 @@ class Timestep:
     def __init__(self, sim, nsteps, item_list=None):
         self.sim = sim
         self.block = Block(sim, [])
-        self.timestep_loop = For(sim, 0, nsteps + 1, self.block)
+        self.timestep_loop = For(sim, 0, nsteps + 1, self.block) if self.sim._generate_whole_program else None
 
         if item_list is not None:
             for item in item_list:
@@ -31,13 +31,13 @@ class Timestep:
                     self.add(item)
 
     def timestep(self):
-        return self.timestep_loop.iter()
+        return self.timestep_loop.iter() if self.sim._generate_whole_program else self.sim.sim_timestep
 
     def add(self, item, exec_every=0, item_else=None, skip_first=False):
         assert exec_every >= 0, "exec_every parameter must be higher or equal than zero!"
         stmts = item if not isinstance(item, Block) else item.statements()
         stmts_else = None
-        ts = self.timestep_loop.iter()
+        ts = self.timestep() 
         self.sim.enter(self.block)
 
         if item_else is not None:
@@ -65,7 +65,7 @@ class Timestep:
         self.sim.capture_statements(False)
 
         block = Block(self.sim, [Call_Void(self.sim, "pairs::start_timer", [Timers.All]),
-                                 self.timestep_loop,
+                                 self.timestep_loop if self.sim._generate_whole_program else self.block,
                                  Call_Void(self.sim, "pairs::stop_timer", [Timers.All])])
 
         self.sim.capture_statements(_capture)
-- 
GitLab