From 00eb515ddc25d4ae32490371c4dfc2ba4036605d Mon Sep 17 00:00:00 2001 From: Behzad Safaei <iwia103h@a0905.nhr.fau.de> Date: Fri, 14 Feb 2025 02:38:27 +0100 Subject: [PATCH] Refactor library generation --- examples/modular/force_reduction.cpp | 5 +- examples/modular/force_reduction.py | 4 +- examples/modular/sd_sample_1_CPU_GPU.cpp | 3 +- examples/modular/sd_sample_2_CPU_GPU.cpp | 3 +- examples/modular/sd_sample_3_CPU.cpp | 5 +- examples/modular/sd_sample_3_GPU.cu | 18 +- examples/modular/spring_dashpot.py | 6 +- .../spring_dashpot.py | 3 +- src/pairs/code_gen/accessor.py | 2 +- src/pairs/code_gen/cgen.py | 135 +++------- src/pairs/code_gen/interface.py | 251 ++++++++++++++++++ src/pairs/ir/block.py | 15 ++ src/pairs/ir/module.py | 12 +- src/pairs/mapping/funcs.py | 17 +- src/pairs/sim/comm.py | 188 +++++++------ src/pairs/sim/domain.py | 6 +- src/pairs/sim/simulation.py | 177 ++++++------ 17 files changed, 546 insertions(+), 304 deletions(-) create mode 100644 src/pairs/code_gen/interface.py diff --git a/examples/modular/force_reduction.cpp b/examples/modular/force_reduction.cpp index b34a345..e5d2394 100644 --- a/examples/modular/force_reduction.cpp +++ b/examples/modular/force_reduction.cpp @@ -16,6 +16,7 @@ int main(int argc, char **argv) { // setup_sim after creating all bodies pairs_sim->setup_sim(); + pairs_sim->update_mass_and_inertia(); // Track particle //------------------------------------------------------------------------------------------- @@ -31,7 +32,7 @@ int main(int argc, char **argv) { // Communicate particles (exchange/ghost) //------------------------------------------------------------------------------------------- - pairs_sim->communicate(); + pairs_sim->communicate(0); ac->update(); // Helper lambdas for demo @@ -89,7 +90,7 @@ int main(int argc, char **argv) { // Do computations //------------------------------------------------------------------------------------------- - pairs_sim->update_cells(); + pairs_sim->update_cells(t); 
pairs_sim->gravity(); pairs_sim->spring_dashpot(); pairs_sim->euler(5e-5); diff --git a/examples/modular/force_reduction.py b/examples/modular/force_reduction.py index 50846cd..af9cea7 100644 --- a/examples/modular/force_reduction.py +++ b/examples/modular/force_reduction.py @@ -102,7 +102,9 @@ psim.set_domain_partitioner(pairs.block_forest()) psim.pbc([False, False, False]) psim.build_cell_lists(linkedCellWidth) -psim.setup(update_mass_and_inertia, symbols={'infinity': math.inf }) +# The order of user-defined functions is not important here since +# they are not used by other subroutines and are only callable individually +psim.compute(update_mass_and_inertia, symbols={'infinity': math.inf }) psim.compute(spring_dashpot, linkedCellWidth) psim.compute(gravity, symbols={'gravity_SI': gravity_SI }) psim.compute(euler, parameters={'dt': pairs.real()}) diff --git a/examples/modular/sd_sample_1_CPU_GPU.cpp b/examples/modular/sd_sample_1_CPU_GPU.cpp index 03b835c..960ffba 100644 --- a/examples/modular/sd_sample_1_CPU_GPU.cpp +++ b/examples/modular/sd_sample_1_CPU_GPU.cpp @@ -20,6 +20,7 @@ int main(int argc, char **argv) { pairs_sim->create_sphere(0.4, 0.4, 0.68, 2, 2, 0, 1000, 0.05, 0, 0); pairs_sim->setup_sim(); + pairs_sim->update_mass_and_inertia(); int num_timesteps = 2000; int vtk_freq = 20; @@ -30,7 +31,7 @@ int main(int argc, char **argv) { pairs_sim->communicate(t); - pairs_sim->update_cells(); + pairs_sim->update_cells(t); pairs_sim->gravity(); pairs_sim->spring_dashpot(); diff --git a/examples/modular/sd_sample_2_CPU_GPU.cpp b/examples/modular/sd_sample_2_CPU_GPU.cpp index 330003f..e18ce39 100644 --- a/examples/modular/sd_sample_2_CPU_GPU.cpp +++ b/examples/modular/sd_sample_2_CPU_GPU.cpp @@ -42,6 +42,7 @@ int main(int argc, char **argv) { pairs_sim->create_sphere(0.4, 0.4, 0.68, 2, 2, 0, 1000, 0.05, 0, 0); pairs_sim->setup_sim(); + pairs_sim->update_mass_and_inertia(); int num_timesteps = 2000; int vtk_freq = 20; @@ -52,7 +53,7 @@ int main(int argc, char 
**argv) { pairs_sim->communicate(t); - pairs_sim->update_cells(); + pairs_sim->update_cells(t); pairs_sim->gravity(); pairs_sim->spring_dashpot(); diff --git a/examples/modular/sd_sample_3_CPU.cpp b/examples/modular/sd_sample_3_CPU.cpp index 315a7a2..002e4f9 100644 --- a/examples/modular/sd_sample_3_CPU.cpp +++ b/examples/modular/sd_sample_3_CPU.cpp @@ -32,8 +32,9 @@ int main(int argc, char **argv) { auto pIsLocalInMyRank = [&](pairs::id_t uid){return ac->uidToIdxLocal(uid) != ac->getInvalidIdx();}; pairs_sim->setup_sim(); + pairs_sim->update_mass_and_inertia(); - pairs_sim->communicate(); + pairs_sim->communicate(0); int num_timesteps = 2000; int vtk_freq = 20; @@ -55,7 +56,7 @@ int main(int argc, char **argv) { // Calculate forces //------------------------------------------------------------------------------------------- - pairs_sim->update_cells(); + pairs_sim->update_cells(t); pairs_sim->gravity(); pairs_sim->spring_dashpot(); diff --git a/examples/modular/sd_sample_3_GPU.cu b/examples/modular/sd_sample_3_GPU.cu index 6861c08..e18bbf0 100644 --- a/examples/modular/sd_sample_3_GPU.cu +++ b/examples/modular/sd_sample_3_GPU.cu @@ -65,8 +65,9 @@ int main(int argc, char **argv) { auto pIsLocalInMyRank = [&](pairs::id_t uid){return ac->uidToIdxLocal(uid) != ac->getInvalidIdx();}; pairs_sim->setup_sim(); + pairs_sim->update_mass_and_inertia(); - pairs_sim->communicate(); + pairs_sim->communicate(0); // PairsAccessor requires an update when particles are communicated ac->update(); @@ -85,15 +86,15 @@ int main(int argc, char **argv) { int idx = ac->uidToIdxLocal(pUid); // Up-to-date position might be on host or device. 
- // Sync position on HostAndDevice before printing it from host and device: - ac->syncPosition(PairsAccessor::HostAndDevice); - + // Sync position on Host before reading it from host: + ac->syncPosition(PairsAccessor::Host); std::cout << "Position [from host] = (" << ac->getPosition(idx)[0] << ", " << ac->getPosition(idx)[1] << ", " << ac->getPosition(idx)[2] << ")" << std::endl; - // Position is synced on both host and device. Print position from device: + // Sync position on Device before reading it from device: + ac->syncPosition(PairsAccessor::Device); print_position<<<1,1>>>(*ac, idx); checkCudaError(cudaDeviceSynchronize(), "print_position"); @@ -102,13 +103,12 @@ int main(int argc, char **argv) { // Calculate forces //------------------------------------------------------------------------------------------- - pairs_sim->update_cells(); + pairs_sim->update_cells(t); pairs_sim->gravity(); pairs_sim->spring_dashpot(); // Change gravitational force on particle pUid //------------------------------------------------------------------------------------------- - // Here we are syncing Uid on Host again for clarity, but no data transfer will happen since Uid is already on host ac->syncUid(PairsAccessor::Host); if(pIsLocalInMyRank(pUid)){ @@ -125,8 +125,8 @@ int main(int argc, char **argv) { checkCudaError(cudaDeviceSynchronize(), "change_gravitational_force"); // Force on device was modified. - // So sync force before continuing the simulation. By default (no args), force is synced on both host and device - ac->syncForce(); + // So sync force before continuing the simulation. 
+ ac->syncForce(PairsAccessor::Host); std::cout << "Force [from host] after changing = (" << ac->getForce(idx)[0] << ", " << ac->getForce(idx)[1] << ", " diff --git a/examples/modular/spring_dashpot.py b/examples/modular/spring_dashpot.py index 4374e78..1fb833d 100644 --- a/examples/modular/spring_dashpot.py +++ b/examples/modular/spring_dashpot.py @@ -98,9 +98,9 @@ psim.set_domain_partitioner(pairs.block_forest()) psim.pbc([False, False, False]) psim.build_cell_lists(linkedCellWidth) -psim.setup(update_mass_and_inertia, symbols={'infinity': math.inf }) - -# The order of user-defined functions is not important here since they are only callable individually +# The order of user-defined functions is not important here since +# they are not used by other subroutines and are only callable individually +psim.compute(update_mass_and_inertia, symbols={'infinity': math.inf }) psim.compute(spring_dashpot, linkedCellWidth) psim.compute(euler, parameters={'dt': pairs.real()}) psim.compute(gravity, symbols={'gravity_SI': gravity_SI }) diff --git a/examples/whole-program-generation/spring_dashpot.py b/examples/whole-program-generation/spring_dashpot.py index f4796cf..397b542 100644 --- a/examples/whole-program-generation/spring_dashpot.py +++ b/examples/whole-program-generation/spring_dashpot.py @@ -142,9 +142,10 @@ psim.read_particle_data( "data/sd_planes.input", ['type', 'mass', 'position', 'n psim.vtk_output(f"output/dem_{target}", frequency=visSpacing) +# The user-defined setup functions are executed only once before the timestep loop psim.setup(update_mass_and_inertia, symbols={'infinity': math.inf }) -# The user-defined functions are added to the timestep loop in the order they are given to 'compute' +# The user-defined compute functions are added to the timestep loop in the order they are given to 'compute' psim.compute(spring_dashpot, linkedCellWidth) psim.compute(gravity, symbols={'gravity_SI': gravity_SI }) psim.compute(euler, symbols={'dt': dt_SI}) diff --git 
a/src/pairs/code_gen/accessor.py b/src/pairs/code_gen/accessor.py index 41b8bd0..eb2a73d 100644 --- a/src/pairs/code_gen/accessor.py +++ b/src/pairs/code_gen/accessor.py @@ -42,7 +42,7 @@ class PairsAcessor: self.sync_ctx_enum() self.update() self.constructor() - self.destructor() + # self.destructor() for p in self.sim.properties: if (p.type()==Types.Vector) or (Types.is_scalar(p.type())): diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 313101d..82b082f 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -149,7 +149,8 @@ class CGen: if self.target.is_gpu(): self.print("#define PAIRS_TARGET_CUDA") - + self.print("#include <math_constants.h>") + if self.target.is_openmp(): self.print("#define PAIRS_TARGET_OPENMP") self.print("#include <omp.h>") @@ -176,10 +177,14 @@ class CGen: def generate_module_header(self, module, definition=True): module_params = [] - if not module.user_defined and not module.interface: + + if not module.interface: module_params += ["PairsRuntime *pairs_runtime", "struct PairsObjects *pobj"] - if module.name=="init_domain" or module.name=="set_domain": + if module.name=="initialize" and self.sim.create_domain_at_initialization: + module_params += ["int argc", "char **argv"] + + if module.name=="set_domain": module_params += ["int argc", "char **argv"] module_params += [f"{Types.c_keyword(self.sim, param.type())} {param.name()}" for param in module.parameters()] @@ -194,7 +199,9 @@ class CGen: self.print("namespace pairs::internal {") self.print.add_indent(4) - for module in self.sim.modules(): + # All modules except the interface ones are declared in the pairs::internal scope + for module in self.sim.modules() + self.sim.udf_modules(): + assert not module.interface self.generate_module_header(module, definition=False) self.print.add_indent(-4) @@ -307,16 +314,7 @@ class CGen: self.print.end() - def generate_library(self, - update_cells_module, - user_defined_modules, - initialize_module, - 
set_domain_module, - setup_sim_module, - reverse_comm_module, - communicate_module, - reset_volatiles_module): - + def generate_library(self): self.generate_interfaces() # Generate CUDA/CPP file with modules ext = ".cu" if self.target.is_gpu() else ".cpp" @@ -345,14 +343,14 @@ class CGen: self.print("namespace pairs::internal {") self.print.add_indent(4) - + for kernel in self.sim.kernels(): self.generate_kernel(kernel) - for module in self.sim.modules(): - if module.name not in ['update_cells', 'initialize', 'set_domain', 'setup_sim', 'reverse_comm', 'communicate', 'reset_volatiles']: - if not module.user_defined: - self.generate_module(module) + # All modules except the interface ones are defined in the pairs::internal scope + for module in self.sim.modules() + self.sim.udf_modules(): + assert not module.interface + self.generate_module(module) self.print.add_indent(-4) self.print("}") @@ -368,106 +366,33 @@ class CGen: self.generate_pairs_object_structure() self.generate_module_decls() - ndims = module.sim.ndims() - nprops = module.sim.properties.nprops() - ncontactprops = module.sim.contact_properties.nprops() - narrays = module.sim.arrays.narrays() - part = DomainPartitioners.c_keyword(module.sim.partitioner()) - self.generate_full_object_names = True self.print("class PairsSimulation {") self.print("private:") self.print(" PairsRuntime *pairs_runtime;") self.print(" struct PairsObjects *pobj;") self.print(" friend class PairsAccessor;") + self.print("") self.print("public:") - self.print.add_indent(4) - for module in user_defined_modules: + # Only interface modules are generated in the PairsSimulation class + for module in self.sim.interface_modules(): self.generate_module(module) - self.print("") - - if set_domain_module: - self.print("void initialize() {") - else: - self.print("void initialize(int argc, char **argv) {") - self.print(f" pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});") - self.print(f" pobj = new 
PairsObjects();") - self.generate_statement(initialize_module.block) - self.print("}") - self.print("") - if set_domain_module: - self.generate_module(set_domain_module) + # Generate a 'use_domain' module only if domain is not predefined in the input script + if not self.sim.create_domain_at_initialization: + self.print("template<typename Domain_T>") + self.print("void use_domain(const std::shared_ptr<Domain_T> &domain_ptr) {") + self.print(" pairs_runtime->useDomain(domain_ptr);") + self.print("}") self.print("") - self.print("template<typename Domain_T>") - self.print("void use_domain(const std::shared_ptr<Domain_T> &domain_ptr) {") - self.print(" pairs_runtime->useDomain(domain_ptr);") - self.print("}") - self.print("") - - self.print("pairs::id_t create_halfspace(double x, double y, double z, double nx, double ny, double nz, int type, int flag){") - self.print(" return pairs::create_halfspace(pairs_runtime, x, y, z, nx, ny, nz, type, flag);") - self.print("}") - self.print("") - - self.print("pairs::id_t create_sphere(double x, double y, double z, double vx, double vy, double vz, double density, double radius, int type, int flag){") - self.print(" return pairs::create_sphere(pairs_runtime, x, y, z, vx, vy, vz, density, radius, type, flag);") - self.print("}") - self.print("") - - self.print("int rank(){ return pairs_runtime->getDomainPartitioner()->getRank();}") - self.print("") - - self.print("int size(){ return pobj->nlocal + pobj->nghost;}") - self.print("") - - self.print("int nlocal(){ return pobj->nlocal;}") - self.print("") - - self.print("int nghost(){ return pobj->nghost;}") - self.print("") - - self.print("void setup_sim() {") - self.generate_statement(setup_sim_module.block) - self.print("}") - self.print("") - - self.print("void update_cells() {") - self.generate_statement(update_cells_module.block) - self.print("}") - self.print("") - - self.print("void reverse_comm() {") - self.generate_statement(reverse_comm_module.block) - self.print("}") - 
self.print("") - - self.print("void communicate(int timestep = 0) {") - self.print(" pobj->sim_timestep = timestep;") - self.generate_statement(communicate_module.block) - self.print("}") - self.print("") - - self.print("void reset_volatiles() {") - self.generate_statement(reset_volatiles_module.block) - self.print("}") - self.print("") - self.print("void vtk_write(const char* filename, int start, int end, int timestep, int frequency) {") self.print(" pairs::vtk_write_data(pairs_runtime, filename, start, end, timestep, frequency);") self.print("}") self.print("") - self.print("void end() {") - self.print(" pairs::print_timers(pairs_runtime);") - self.print(" pairs::print_stats(pairs_runtime, pobj->nlocal, pobj->nghost);") - self.print(" delete pobj;") - self.print(" delete pairs_runtime;") - self.print("}") - self.print.add_indent(-4) self.print("};") @@ -569,12 +494,13 @@ class CGen: if self.debug: self.print(f"PAIRS_DEBUG(\"\\n{module.name}\\n\");") - if not module.user_defined and not module.interface: + if not module.interface: self.generate_module_declerations(module) self.print.add_indent(-4) self.generate_statement(module.block) self.print("}") + self.print("") def generate_kernel(self, kernel): kernel_params = "int range_start" @@ -1228,7 +1154,10 @@ class CGen: return ast_node.value[index] if isinstance(ast_node.value, float) and math.isinf(ast_node.value): - return f"std::numeric_limits<{self.real_type()}>::infinity()" + if self.kernel_context: + return "CUDART_INF" + else: + return f"std::numeric_limits<{self.real_type()}>::infinity()" return ast_node.value diff --git a/src/pairs/code_gen/interface.py b/src/pairs/code_gen/interface.py new file mode 100644 index 0000000..0e52dee --- /dev/null +++ b/src/pairs/code_gen/interface.py @@ -0,0 +1,251 @@ +from pairs.ir.block import Block, pairs_interface_block +from pairs.ir.functions import Call_Void, Call, Call_Int +from pairs.ir.parameters import Parameter +from pairs.ir.ret import Return +from 
pairs.ir.scalars import ScalarOp +from pairs.sim.domain import UpdateDomain, SetDomain +from pairs.sim.cell_lists import BuildCellListsStencil +from pairs.sim.comm import Synchronize, Borders, Exchange, ReverseComm +from pairs.ir.types import Types +from pairs.ir.branches import Filter, Branch +from pairs.sim.cell_lists import BuildCellLists, BuildCellListsStencil, PartitionCellLists, BuildCellNeighborLists +from pairs.sim.neighbor_lists import BuildNeighborLists +from pairs.sim.variables import DeclareVariables +from pairs.sim.arrays import DeclareArrays +from pairs.sim.properties import AllocateProperties, AllocateContactProperties, ResetVolatileProperties +from pairs.sim.features import AllocateFeatureProperties +from pairs.sim.instrumentation import RegisterMarkers, RegisterTimers +from pairs.sim.grid import MutableGrid +from pairs.sim.domain_partitioners import DomainPartitioners +from pairs.ir.print import PrintCode +from pairs.ir.assign import Assign +from pairs.sim.contact_history import BuildContactHistory, ClearUnusedContactHistory, ResetContactHistoryUsageStatus +from pairs.sim.thermo import ComputeThermo + +class InterfaceModules: + def __init__(self, sim): + self.sim = sim + + def create_all(self): + self.initialize() + + # Generate a 'set_domain' module only if domain is not pre-set in the input script + if not self.sim.create_domain_at_initialization: + self.set_domain() + + self.setup_sim() + self.update_cells(self.sim.reneighbor_frequency) + self.communicate(self.sim.reneighbor_frequency) + self.reverse_comm() + self.reset_volatiles() + + if self.sim._use_contact_history: + if self.neighbor_lists: + self.build_contact_history(self.sim.reneighbor_frequency) + self.reset_contact_history() + + if self.sim._compute_thermo != 0: + self.compute_thermo(self.sim._compute_thermo) + + self.rank() + self.nlocal() + self.nghost() + self.size() + self.create_sphere() + self.create_halfspace() + self.dem_sc_grid() + self.end() + + @pairs_interface_block + def 
initialize(self): + self.sim.module_name('initialize') + nprops = self.sim.properties.nprops() + ncontactprops = self.sim.contact_properties.nprops() + narrays = self.sim.arrays.narrays() + part = DomainPartitioners.c_keyword(self.sim.partitioner()) + + PrintCode(self.sim, f"pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});") + PrintCode(self.sim, f"pobj = new PairsObjects();") + + inits = Block.from_list(self.sim, [ + DeclareVariables(self.sim), + DeclareArrays(self.sim), + AllocateProperties(self.sim), + AllocateContactProperties(self.sim), + AllocateFeatureProperties(self.sim), + RegisterTimers(self.sim), + RegisterMarkers(self.sim) + ]) + + if self.sim.create_domain_at_initialization: + self.sim.add_statement(Block.merge_blocks(inits, self.sim.create_domain)) + else: + assert self.sim.grid is None, "A grid already exists" + self.sim.grid = MutableGrid(self.sim, self.sim.dims) + self.sim.add_statement(inits) + + @pairs_interface_block + def set_domain(self): + assert isinstance(self.sim.grid, MutableGrid) + self.sim.module_name('set_domain') + self.sim.add_statement(SetDomain(self.sim)) + + @pairs_interface_block + def setup_sim(self): + self.sim.module_name('setup_sim') + self.sim.add_statement(self.sim.setup_particles) + self.sim.add_statement(UpdateDomain(self.sim)) + self.sim.add_statement(BuildCellListsStencil(self.sim, self.sim.cell_lists)) + + @pairs_interface_block + def reset_volatiles(self): + self.sim.module_name('reset_volatiles') + self.sim.add_statement(ResetVolatileProperties(self.sim)) + + @pairs_interface_block + def update_cells(self, reneighbor_frequency=1): + self.sim.module_name('update_cells') + timestep = Parameter(self.sim, f'timestep', Types.Int32) + cond = ScalarOp.inline(ScalarOp.or_op( + ScalarOp.cmp((timestep + 1) % reneighbor_frequency, 0), + ScalarOp.cmp(timestep, 0) + )) + + subroutines = [BuildCellLists(self.sim, self.sim.cell_lists), + PartitionCellLists(self.sim, self.sim.cell_lists)] + + # Add 
routine to build neighbor-lists per cell + if self.sim._store_neighbors_per_cell: + subroutines.append(BuildCellNeighborLists(self.sim, self.sim.cell_lists)) + + # Add routine to build neighbor-lists per particle (standard Verlet Lists) + if self.sim.neighbor_lists: + subroutines.append(BuildNeighborLists(self.sim, self.sim.neighbor_lists)) + + self.sim.add_statement(Filter(self.sim, cond, Block.from_list(self.sim, subroutines))) + + @pairs_interface_block + def communicate(self, reneighbor_frequency=1): + self.sim.module_name('communicate') + timestep = Parameter(self.sim, f'timestep', Types.Int32) + cond = ScalarOp.inline(ScalarOp.or_op( + ScalarOp.cmp((timestep + 1) % reneighbor_frequency, 0), + ScalarOp.cmp(timestep, 0) + )) + + exchange = Filter(self.sim, cond, Exchange(self.sim._comm)) + border_sync = Branch(self.sim, cond, blk_if = Borders(self.sim._comm), + blk_else = Synchronize(self.sim._comm)) + + self.sim.add_statement(exchange) + self.sim.add_statement(border_sync) + + @pairs_interface_block + def reverse_comm(self): + self.sim.module_name('reverse_comm') + self.sim.add_statement(ReverseComm(self.sim._comm, reduce=True)) + + @pairs_interface_block + def build_contact_history(self, reneighbor_frequency=1): + self.sim.module_name('build_contact_history') + timestep = Parameter(self.sim, f'timestep', Types.Int32) + cond = ScalarOp.inline(ScalarOp.or_op( + ScalarOp.cmp((timestep + 1) % reneighbor_frequency, 0), + ScalarOp.cmp(timestep, 0) + )) + + self.sim.add_statement( + Filter(self.sim, cond, + BuildContactHistory(self.sim, self.sim._contact_history, self.sim.cell_lists))) + + @pairs_interface_block + def reset_contact_history(self): + self.sim.module_name('reset_contact_history') + self.sim.add_statement(ResetContactHistoryUsageStatus(self.sim, self.sim._contact_history)) + self.sim.add_statement(ClearUnusedContactHistory(self.sim, self.sim._contact_history)) + + @pairs_interface_block + def compute_thermo(self): + 
self.sim.module_name('compute_thermo') + self.sim.add_statement(ComputeThermo(self.sim)) + + @pairs_interface_block + def rank(self): + self.sim.module_name('rank') + Return(self.sim, self.sim.domain_partitioning().rank) + + @pairs_interface_block + def nlocal(self): + self.sim.module_name('nlocal') + Return(self.sim, self.sim.nlocal) + + @pairs_interface_block + def nghost(self): + self.sim.module_name('nghost') + Return(self.sim, self.sim.nghost) + + @pairs_interface_block + def size(self): + self.sim.module_name('size') + Return(self.sim, ScalarOp.inline(self.sim.nlocal + self.sim.nghost)) + + @pairs_interface_block + def create_sphere(self): + self.sim.module_name('create_sphere') + x = Parameter(self.sim, 'x', Types.Real) + y = Parameter(self.sim, 'y', Types.Real) + z = Parameter(self.sim, 'z', Types.Real) + vx = Parameter(self.sim, 'vx', Types.Real) + vy = Parameter(self.sim, 'vy', Types.Real) + vz = Parameter(self.sim, 'vz', Types.Real) + density = Parameter(self.sim, 'density', Types.Real) + radius = Parameter(self.sim, 'radius', Types.Real) + ptype = Parameter(self.sim, 'type', Types.Real) + flag = Parameter(self.sim, 'flag', Types.Real) + + Return(self.sim, Call(self.sim, "pairs::create_sphere", + [x, y, z, vx, vy, vz, + density, radius, ptype, flag], Types.UInt64)) + + @pairs_interface_block + def create_halfspace(self): + self.sim.module_name('create_halfspace') + x = Parameter(self.sim, 'x', Types.Real) + y = Parameter(self.sim, 'y', Types.Real) + z = Parameter(self.sim, 'z', Types.Real) + nx = Parameter(self.sim, 'nx', Types.Real) + ny = Parameter(self.sim, 'ny', Types.Real) + nz = Parameter(self.sim, 'nz', Types.Real) + ptype = Parameter(self.sim, 'type', Types.Real) + flag = Parameter(self.sim, 'flag', Types.Real) + + Return(self.sim, Call(self.sim, "pairs::create_halfspace", + [x, y, z, nx, ny, nz, ptype, flag], Types.UInt64)) + + @pairs_interface_block + def dem_sc_grid(self): + self.sim.module_name('dem_sc_grid') + xmax = Parameter(self.sim, 
'xmax', Types.Real) + ymax = Parameter(self.sim, 'ymax', Types.Real) + zmax = Parameter(self.sim, 'zmax', Types.Real) + spacing = Parameter(self.sim, 'spacing', Types.Real) + diameter = Parameter(self.sim, 'diameter', Types.Real) + min_diameter = Parameter(self.sim, 'min_diameter', Types.Real) + max_diameter = Parameter(self.sim, 'max_diameter', Types.Real) + initial_velocity = Parameter(self.sim, 'initial_velocity', Types.Real) + particle_density = Parameter(self.sim, 'particle_density', Types.Real) + ntypes = Parameter(self.sim, 'ntypes', Types.Int32) + + Assign(self.sim, self.sim.nlocal, + Call_Int(self.sim, "pairs::dem_sc_grid", + [xmax, ymax, zmax, spacing, diameter, min_diameter, max_diameter, + initial_velocity, particle_density, ntypes])) + Return(self.sim, self.sim.nlocal) + + @pairs_interface_block + def end(self): + self.sim.module_name('end') + Call_Void(self.sim, "pairs::print_timers", []) + Call_Void(self.sim, "pairs::print_stats", [self.sim.nlocal, self.sim.nghost]) + PrintCode(self.sim, "delete pobj;") + PrintCode(self.sim, "delete pairs_runtime;") diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index 1a0809b..2a14ea2 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -42,6 +42,21 @@ def pairs_device_block(func): return inner +def pairs_interface_block(func): + def inner(*args, **kwargs): + sim = args[0].sim # self.sim + sim.init_block() + func(*args, **kwargs) + return Module(sim, + name=sim._module_name, + block=Block(sim, sim._block), + resizes_to_check=sim._resizes_to_check, + check_properties_resize=sim._check_properties_resize, + run_on_device=False, + interface=True) + + return inner + class Block(ASTNode): def __init__(self, sim, stmts): super().__init__(sim) diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index 44539db..ded67ac 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -37,7 +37,17 @@ class Module(ASTNode): self._interface = interface self._return_type = Types.Void 
self._profile = False - sim.add_module(self) + + if user_defined: + assert not interface, ("User-defined modules can't be part of the interface directly." + "Wrap them inside seperate interface modules.") + sim.add_udf_module(self) + else: + if interface: + sim.add_interface_module(self) + else: + sim.add_module(self) + Module.last_module += 1 def __str__(self): diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py index 16abbab..b4ec2f8 100644 --- a/src/pairs/mapping/funcs.py +++ b/src/pairs/mapping/funcs.py @@ -335,12 +335,13 @@ def compute(sim, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=F ir.visit(tree) - if pre_step: - sim.build_pre_step_module_with_statements(skip_first=skip_first, profile=True) - + if sim._generate_whole_program: + if pre_step: + sim.build_pre_step_module_with_statements(skip_first=skip_first, profile=True) + else: + sim.build_module_with_statements(skip_first=skip_first, profile=True) else: - sim.build_module_with_statements(skip_first=skip_first, profile=True) - + sim.build_user_defined_function() def setup(sim, func, symbols={}): src = inspect.getsource(func) @@ -366,4 +367,8 @@ def setup(sim, func, symbols={}): ir.add_symbols({params[0]: i}) ir.visit(tree) - sim.build_setup_module_with_statements() + if sim._generate_whole_program: + sim.build_setup_module_with_statements() + else: + sim.build_user_defined_function() + diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py index 0508163..6ec0179 100644 --- a/src/pairs/sim/comm.py +++ b/src/pairs/sim/comm.py @@ -49,47 +49,31 @@ class Comm: self.recv_offsets_reverse = sim.add_array('recv_offsets_reverse', [dom_part.nranks_capacity], Types.Int32) self.recv_buffer_reverse = sim.add_array('recv_buffer_reverse', [self.recv_capacity, self.elem_capacity], Types.Real, arr_sync=False) + +class Synchronize(Lowerable): + def __init__(self, comm): + self.sim = comm.sim + self.comm = comm + @pairs_inline - def synchronize(self): + def lower(self): # Every property 
that is not constant across timesteps and have neighbor accesses during any # interaction kernel (i.e. property[j] in force calculation kernel) prop_names = ['position', 'linear_velocity', 'angular_velocity'] prop_list = [self.sim.property(p) for p in prop_names if self.sim.property(p) is not None] - PackAllGhostParticles(self, prop_list) - CommunicateAllData(self, prop_list) - UnpackAllGhostParticles(self, prop_list) - - @pairs_host_block - def reverse_comm(self, reduce=False): - self.sim.module_name(f"reverse_comm") - prop_list = self.sim.properties.reduction_props() + PackAllGhostParticles(self.comm, prop_list) + CommunicateAllData(self.comm, prop_list) + UnpackAllGhostParticles(self.comm, prop_list) - if prop_list : - for step in range(self.dom_part.number_of_steps() - 1, -1, -1): - if self.sim._target.is_gpu(): - CopyArray(self.sim, self.nsend, Contexts.Host, Actions.ReadOnly) - CopyArray(self.sim, self.nrecv, Contexts.Host, Actions.ReadOnly) - CopyArray(self.sim, self.send_offsets, Contexts.Host, Actions.ReadOnly) - CopyArray(self.sim, self.recv_offsets, Contexts.Host, Actions.ReadOnly) - - CopyArray(self.sim, self.nsend_reverse, Contexts.Host, Actions.WriteOnly) - CopyArray(self.sim, self.nrecv_reverse, Contexts.Host, Actions.WriteOnly) - CopyArray(self.sim, self.send_offsets_reverse, Contexts.Host, Actions.WriteOnly) - CopyArray(self.sim, self.recv_offsets_reverse, Contexts.Host, Actions.WriteOnly) - - for j in self.dom_part.step_indexes(step): - Assign(self.sim, self.nsend_reverse[j], self.nrecv[j]) - Assign(self.sim, self.nrecv_reverse[j], self.nsend[j]) - Assign(self.sim, self.send_offsets_reverse[j], self.recv_offsets[j]) - Assign(self.sim, self.recv_offsets_reverse[j], self.send_offsets[j]) - PackGhostParticlesReverse(self, step, prop_list) - CommunicateDataReverse(self, step, prop_list) - UnpackGhostParticlesReverse(self, step, prop_list, reduce) +class Borders(Lowerable): + def __init__(self, comm): + self.sim = comm.sim + self.comm = comm 
@pairs_inline - def borders(self): + def lower(self): # Every property that has neighbor accesses during any interaction kernel (i.e. property[j] # exists in any force calculation kernel) # We ignore normal because there should be no ghost half-spaces @@ -106,84 +90,128 @@ class Comm: prop_list = [self.sim.property(p) for p in prop_names if self.sim.property(p) is not None] - Assign(self.sim, self.nsend_all, 0) + Assign(self.sim, self.comm.nsend_all, 0) Assign(self.sim, self.sim.nghost, 0) - for step in range(self.dom_part.number_of_steps()): + for step in range(self.comm.dom_part.number_of_steps()): if self.sim._target.is_gpu(): - CopyArray(self.sim, self.nsend, Contexts.Host, Actions.Ignore) - CopyArray(self.sim, self.nrecv, Contexts.Host, Actions.Ignore) + CopyArray(self.sim, self.comm.nsend, Contexts.Host, Actions.Ignore) + CopyArray(self.sim, self.comm.nrecv, Contexts.Host, Actions.Ignore) - for j in self.dom_part.step_indexes(step): - Assign(self.sim, self.nsend[j], 0) - Assign(self.sim, self.nrecv[j], 0) + for j in self.comm.dom_part.step_indexes(step): + Assign(self.sim, self.comm.nsend[j], 0) + Assign(self.sim, self.comm.nrecv[j], 0) if self.sim._target.is_gpu(): - CopyArray(self.sim, self.nsend, Contexts.Device, Actions.Ignore) - CopyArray(self.sim, self.nrecv, Contexts.Device, Actions.Ignore) + CopyArray(self.sim, self.comm.nsend, Contexts.Device, Actions.Ignore) + CopyArray(self.sim, self.comm.nrecv, Contexts.Device, Actions.Ignore) - DetermineGhostParticles(self, step, self.sim.cell_spacing()) - CommunicateSizes(self, step) - SetCommunicationOffsets(self, step) - PackGhostParticles(self, step, prop_list) - CommunicateData(self, step, prop_list) - UnpackGhostParticles(self, step, prop_list) + DetermineGhostParticles(self.comm, step, self.sim.cell_spacing()) + CommunicateSizes(self.comm, step) + SetCommunicationOffsets(self.comm, step) + PackGhostParticles(self.comm, step, prop_list) + CommunicateData(self.comm, step, prop_list) + 
UnpackGhostParticles(self.comm, step, prop_list) - step_nrecv = self.dom_part.reduce_sum_step_indexes(step, self.nrecv) + step_nrecv = self.comm.dom_part.reduce_sum_step_indexes(step, self.comm.nrecv) Assign(self.sim, self.sim.nghost, self.sim.nghost + step_nrecv) + +class Exchange(Lowerable): + def __init__(self, comm): + self.sim = comm.sim + self.comm = comm + @pairs_inline - def exchange(self): + def lower(self): # Every property except volatiles prop_list = self.sim.properties.non_volatiles() - for step in range(self.dom_part.number_of_steps()): - Assign(self.sim, self.nsend_all, 0) - Assign(self.sim, self.sim.nghost, 0) + for step in range(self.comm.dom_part.number_of_steps()): + Assign(self.comm.sim, self.comm.nsend_all, 0) + Assign(self.comm.sim, self.sim.nghost, 0) for s in range(step + 1): - for j in self.dom_part.step_indexes(s): - Assign(self.sim, self.nsend[j], 0) - Assign(self.sim, self.nrecv[j], 0) - Assign(self.sim, self.send_offsets[j], 0) - Assign(self.sim, self.recv_offsets[j], 0) - Assign(self.sim, self.nsend_contact[j], 0) - Assign(self.sim, self.nrecv_contact[j], 0) - Assign(self.sim, self.contact_soffsets[j], 0) - Assign(self.sim, self.contact_soffsets[j], 0) + for j in self.comm.dom_part.step_indexes(s): + Assign(self.comm.sim, self.comm.nsend[j], 0) + Assign(self.comm.sim, self.comm.nrecv[j], 0) + Assign(self.comm.sim, self.comm.send_offsets[j], 0) + Assign(self.comm.sim, self.comm.recv_offsets[j], 0) + Assign(self.comm.sim, self.comm.nsend_contact[j], 0) + Assign(self.comm.sim, self.comm.nrecv_contact[j], 0) + Assign(self.comm.sim, self.comm.contact_soffsets[j], 0) + Assign(self.comm.sim, self.comm.contact_soffsets[j], 0) if self.sim._target.is_gpu(): - CopyArray(self.sim, self.nsend, Contexts.Device, Actions.Ignore) - CopyArray(self.sim, self.nrecv, Contexts.Device, Actions.Ignore) + CopyArray(self.comm.sim, self.comm.nsend, Contexts.Device, Actions.Ignore) + CopyArray(self.comm.sim, self.comm.nrecv, Contexts.Device, Actions.Ignore) - 
DetermineGhostParticles(self, step, 0.0) - CommunicateSizes(self, step) - SetCommunicationOffsets(self, step) - PackGhostParticles(self, step, prop_list) + DetermineGhostParticles(self.comm, step, 0.0) + CommunicateSizes(self.comm, step) + SetCommunicationOffsets(self.comm, step) + PackGhostParticles(self.comm, step, prop_list) if self.sim._target.is_gpu(): - send_map_size = self.nsend_all * Sizeof(self.sim, Types.Int32) - exchg_flag_size = self.sim.nlocal * Sizeof(self.sim, Types.Int32) - CopyArray(self.sim, self.send_map, Contexts.Host, Actions.ReadOnly, send_map_size) - CopyArray(self.sim, self.exchg_flag, Contexts.Host, Actions.ReadOnly, exchg_flag_size) + send_map_size = self.comm.nsend_all * Sizeof(self.comm.sim, Types.Int32) + exchg_flag_size = self.sim.nlocal * Sizeof(self.comm.sim, Types.Int32) + CopyArray(self.comm.sim, self.comm.send_map, Contexts.Host, Actions.ReadOnly, send_map_size) + CopyArray(self.comm.sim, self.comm.exchg_flag, Contexts.Host, Actions.ReadOnly, exchg_flag_size) - RemoveExchangedParticles_part1(self) + RemoveExchangedParticles_part1(self.comm) if self.sim._target.is_gpu(): - exchg_copy_to_size = self.nsend_all * Sizeof(self.sim, Types.Int32) + exchg_copy_to_size = self.comm.nsend_all * Sizeof(self.comm.sim, Types.Int32) CopyArray( - self.sim, self.exchg_copy_to, Contexts.Device, Actions.ReadOnly, exchg_copy_to_size) + self.comm.sim, self.comm.exchg_copy_to, Contexts.Device, Actions.ReadOnly, exchg_copy_to_size) - RemoveExchangedParticles_part2(self, prop_list) - CommunicateData(self, step, prop_list) - UnpackGhostParticles(self, step, prop_list) + RemoveExchangedParticles_part2(self.comm, prop_list) + CommunicateData(self.comm, step, prop_list) + UnpackGhostParticles(self.comm, step, prop_list) if self.sim._use_contact_history: - PackContactHistoryData(self, step) - CommunicateContactHistoryData(self, step) - UnpackContactHistoryData(self, step) + PackContactHistoryData(self.comm, step) + CommunicateContactHistoryData(self.comm, 
step) + UnpackContactHistoryData(self.comm, step) + + ChangeSizeAfterExchange(self.comm, step) + + +class ReverseComm(Lowerable): + def __init__(self, comm, reduce=False): + self.sim = comm.sim + self.comm = comm + self.reduce = reduce + + @pairs_inline + def lower(self): + prop_list = self.sim.properties.reduction_props() + + if prop_list : + for step in range(self.comm.dom_part.number_of_steps() - 1, -1, -1): + if self.sim._target.is_gpu(): + CopyArray(self.sim, self.comm.nsend, Contexts.Host, Actions.ReadOnly) + CopyArray(self.sim, self.comm.nrecv, Contexts.Host, Actions.ReadOnly) + CopyArray(self.sim, self.comm.send_offsets, Contexts.Host, Actions.ReadOnly) + CopyArray(self.sim, self.comm.recv_offsets, Contexts.Host, Actions.ReadOnly) + + CopyArray(self.sim, self.comm.nsend_reverse, Contexts.Host, Actions.WriteOnly) + CopyArray(self.sim, self.comm.nrecv_reverse, Contexts.Host, Actions.WriteOnly) + CopyArray(self.sim, self.comm.send_offsets_reverse, Contexts.Host, Actions.WriteOnly) + CopyArray(self.sim, self.comm.recv_offsets_reverse, Contexts.Host, Actions.WriteOnly) + + for j in self.comm.dom_part.step_indexes(step): + Assign(self.sim, self.comm.nsend_reverse[j], self.comm.nrecv[j]) + Assign(self.sim, self.comm.nrecv_reverse[j], self.comm.nsend[j]) + Assign(self.sim, self.comm.send_offsets_reverse[j], self.comm.recv_offsets[j]) + Assign(self.sim, self.comm.recv_offsets_reverse[j], self.comm.send_offsets[j]) + + PackGhostParticlesReverse(self.comm, step, prop_list) + CommunicateDataReverse(self.comm, step, prop_list) + UnpackGhostParticlesReverse(self.comm, step, prop_list, self.reduce) + + + - ChangeSizeAfterExchange(self, step) class CommunicateSizes(Lowerable): diff --git a/src/pairs/sim/domain.py b/src/pairs/sim/domain.py index 5593840..9a85792 100644 --- a/src/pairs/sim/domain.py +++ b/src/pairs/sim/domain.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_inline, pairs_host_block +from pairs.ir.block import pairs_inline from pairs.ir.parameters import 
Parameter from pairs.ir.types import Types from pairs.ir.assign import Assign @@ -17,10 +17,8 @@ class SetDomain(Lowerable): def __init__(self, sim): super().__init__(sim) - @pairs_host_block + @pairs_inline def lower(self): - self.sim.module_name('init_domain') - for d in range(self.sim.ndims()): dmin = Parameter(self.sim, f'd{d}_min', Types.Real) Assign(self.sim, self.sim.grid.min(d), dmin) diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 89cfb59..f58082e 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -13,7 +13,7 @@ from pairs.ir.variables import Variables from pairs.mapping.funcs import compute, setup from pairs.sim.arrays import DeclareArrays from pairs.sim.cell_lists import CellLists, BuildCellLists, BuildCellListsStencil, PartitionCellLists, BuildCellNeighborLists -from pairs.sim.comm import Comm +from pairs.sim.comm import Comm, Synchronize, Borders, Exchange, ReverseComm from pairs.sim.contact_history import ContactHistory, BuildContactHistory, ClearUnusedContactHistory, ResetContactHistoryUsageStatus from pairs.sim.copper_fcc_lattice import CopperFCCLattice from pairs.sim.dem_sc_grid import DEMSCGrid @@ -32,6 +32,7 @@ from pairs.sim.timestep import Timestep from pairs.sim.variables import DeclareVariables from pairs.sim.vtk import VTKWrite from pairs.transformations import Transformations +from pairs.code_gen.interface import InterfaceModules class Simulation: @@ -92,13 +93,17 @@ class Simulation: # Different segments of particle code/functions self.create_domain = Block(self, []) - self.generate_set_domain_module = True + self.create_domain_at_initialization = False self.setup_particles = Block(self, []) self.module_list = [] self.kernel_list = [] - # User-defined modules + # Individual user-defined and interface modules are created only when generate_whole_program is False + self.udf_module_list = [] + self.interface_module_list = [] + + # User-defined functions to be called by other subroutines 
(used only when generate_whole_program is True) self.setup_functions = [] self.pre_step_functions = [] self.functions = [] @@ -114,6 +119,7 @@ class Simulation: # Domain partitioning self._dom_part = None self._partitioner = None + self._comm = None # Contact history self._use_contact_history = use_contact_history @@ -164,11 +170,30 @@ class Simulation: def max_shapes(self): return len(self._shapes) + def add_udf_module(self, module): + assert isinstance(module, Module), "add_udf_module(): Given parameter is not of type Module!" + assert module.user_defined and not module.interface + if module.name not in [m.name for m in self.udf_module_list]: + self.udf_module_list.append(module) + + def add_interface_module(self, module): + assert isinstance(module, Module), "add_interface_module(): Given parameter is not of type Module!" + assert module.interface and not module.user_defined + if module.name not in [m.name for m in self.interface_module_list]: + self.interface_module_list.append(module) + def add_module(self, module): assert isinstance(module, Module), "add_module(): Given parameter is not of type Module!" + assert not module.interface and not module.user_defined if module.name not in [m.name for m in self.module_list]: self.module_list.append(module) + def interface_modules(self): + return self.interface_module_list + + def udf_modules(self): + return self.udf_module_list + def modules(self): """List simulation modules, with main always in the last position""" @@ -279,7 +304,7 @@ class Simulation: If the domain is set through this function, the 'set_domain' module won't be generated in the modular version. Use this function only if you do not need to set domain at runtime. 
This function is required only for whole-program generation.""" - self.generate_set_domain_module = False + self.create_domain_at_initialization = True self.grid = Grid3D(self, grid[0], grid[1], grid[2], grid[3], grid[4], grid[5]) self.create_domain.add_statement(InitializeDomain(self)) @@ -357,7 +382,7 @@ class Simulation: block=Block(self, self._block), resizes_to_check=self._resizes_to_check, check_properties_resize=self._check_properties_resize, - run_on_device=False)) + run_on_device=True)) def build_pre_step_module_with_statements(self, run_on_device=True, skip_first=False, profile=False): """Build a Module in the pre-step part of the program using the last initialized block""" @@ -392,6 +417,16 @@ class Simulation: else: self.functions.append(module) + def build_user_defined_function(self, run_on_device=True): + """Build a user-defined Module that will be callable separately as part of the interface""" + Module(self, name=self._module_name, + block=Block(self, self._block), + resizes_to_check=self._resizes_to_check, + check_properties_resize=self._check_properties_resize, + run_on_device=run_on_device, + user_defined=True) + + def capture_statements(self, capture=True): """When toggled, all constructed statements are captured and automatically added to the last initialized block""" self._capture_statements = capture @@ -454,12 +489,37 @@ class Simulation: def generate(self): """Generate the code for the simulation""" - assert self._target is not None, "Target not specified!"
- # Initialize communication instance with specified domain-partitioner - comm = Comm(self, self._dom_part) - reverse_comm_module = comm.reverse_comm(reduce=True) + # Initialize communication instance with the specified domain-partitioner + self._comm = Comm(self, self._dom_part) + + if self._generate_whole_program: + self.generate_program() + else: + self.generate_library() + + def generate_library(self): + InterfaceModules(self).create_all() + + # User defined functions are wrapped inside separate interface modules here. + # The udf's have the same name as their interface module but they get implemented in the pairs::internal scope. + for m in self.udf_module_list: + module = Module(self, name=m.name, block=Block(self, m), interface=True) + module._id = m._id + + Transformations(self.interface_modules(), self._target).apply_all() + + # Generate library + self.code_gen.generate_library() + + # Generate getters for the runtime functions + self.code_gen.generate_interfaces() + + def generate_program(self): + assert self.grid, "No domain is created. Set domain bounds with 'set_domain'." + + reverse_comm_module = ReverseComm(self._comm, reduce=True) # Params that determine when a method must be called only when reneighboring every_reneighbor_params = {'every': self.reneighbor_frequency} @@ -470,8 +530,8 @@ class Simulation: timestep_procedures += self.pre_step_functions comm_routine = [ - (comm.exchange(), every_reneighbor_params), - (comm.borders(), comm.synchronize(), every_reneighbor_params) + (Exchange(self._comm), every_reneighbor_params), + (Borders(self._comm), Synchronize(self._comm), every_reneighbor_params) ] if self._generate_whole_program: @@ -546,86 +606,25 @@ class Simulation: self.leave() # Combine everything into a whole program - if self._generate_whole_program: - assert self.grid, "No domain is created. Set domain bounds with 'set_domain'."
- - # Initialization and setup functions, together with time-step loop - # UpdateDomain is added after setup_particles because particles must be already present in the simulation - body = Block.from_list(self, [ - self.create_domain, - self.setup_particles, - UpdateDomain(self), - self.setup_functions, - BuildCellListsStencil(self, self.cell_lists), - timestep.as_block() - ]) - - program = Module(self, name='main', block=Block.merge_blocks(inits, body)) - - # Apply transformations - transformations = Transformations(program, self._target) - transformations.apply_all() - - # Generate program - self.code_gen.generate_program(program) - - # Generate a small library to be called - else: - update_cells = [m[0] if isinstance(m, tuple) else m for m in update_cells] - update_cells_module = Module(self, name='update_cells', block=Block.from_list(self, update_cells)) - - # Either generate a set_domain module, or create domain during initialization - if self.generate_set_domain_module: - self.grid = MutableGrid(self, self.dims) - initialize_module = Module(self, name='initialize', block=inits) - set_domain_module = Module(self, name='set_domain', block=Block(self, SetDomain(self)), interface=True) - else: - initialize_module = Module(self, name='initialize', block=Block.merge_blocks(inits, self.create_domain)) - set_domain_module = None - - setup_sim = Block.from_list(self, [ - self.setup_particles, - UpdateDomain(self), - self.setup_functions, - BuildCellListsStencil(self, self.cell_lists), - ]) - - setup_sim_module = Module(self, name='setup_sim', block=setup_sim) - communicate_module = Module(self, name='communicate', block=Timestep(self, 0, comm_routine).block) - reset_volatiles_module = Module(self, name='reset_volatiles', block=Block(self, ResetVolatileProperties(self))) - - modules_list = [ - update_cells_module, - initialize_module, - setup_sim_module, - reverse_comm_module, - communicate_module, - reset_volatiles_module - ] - - if self.generate_set_domain_module: - 
modules_list += [set_domain_module] - - Transformations(modules_list, self._target).apply_all() + # Initialization and setup functions, together with time-step loop + # UpdateDomain is added after setup_particles because particles must be already present in the simulation + body = Block.from_list(self, [ + self.create_domain, + self.setup_particles, + UpdateDomain(self), + self.setup_functions, + BuildCellListsStencil(self, self.cell_lists), + timestep.as_block() + ]) - # user defined modules are transformed seperately as indvidual modules - # i.e. they are transformed once again if already transformed in setup_sim or do_timestep - udf_internal = self.setup_functions + self.pre_step_functions + self.functions - udf_internal = [m[0] if isinstance(m, tuple) else m for m in udf_internal] - user_defined_modules = [Module(self, name=m.name, block=Block(self, m), user_defined=True) for m in udf_internal] - for i, m in enumerate(user_defined_modules): - m._id = udf_internal[i]._id + program = Module(self, name='main', block=Block.merge_blocks(inits, body)) - Transformations(user_defined_modules, self._target).apply_all() + # Apply transformations + transformations = Transformations(program, self._target) + transformations.apply_all() - # Generate library - self.code_gen.generate_library(update_cells_module, - user_defined_modules, - initialize_module, - set_domain_module, - setup_sim_module, - reverse_comm_module, - communicate_module, - reset_volatiles_module) + # Generate whole program + self.code_gen.generate_program(program) + # Generate getters for the runtime functions self.code_gen.generate_interfaces() -- GitLab