From 38a012c7405f17f7052b6146e0eeefdbb55206ff Mon Sep 17 00:00:00 2001 From: Behzad Safaei <iwia103h@alex2.nhr.fau.de> Date: Tue, 28 Jan 2025 12:54:54 +0100 Subject: [PATCH] Add support for custrom args in user-defined modules --- examples/dem_sd.py | 7 +- examples/main.cpp | 164 ++++++++++++++++---------- src/pairs/analysis/devices.py | 6 +- src/pairs/analysis/modules.py | 5 + src/pairs/code_gen/cgen.py | 211 +++++++++++++++++++++------------- src/pairs/ir/declaration.py | 2 +- src/pairs/ir/device.py | 3 + src/pairs/ir/functions.py | 8 ++ src/pairs/ir/kernel.py | 16 +++ src/pairs/ir/module.py | 26 ++++- src/pairs/ir/parameters.py | 18 +++ src/pairs/mapping/funcs.py | 12 +- src/pairs/sim/simulation.py | 55 +++++++-- 13 files changed, 370 insertions(+), 163 deletions(-) create mode 100644 src/pairs/ir/parameters.py diff --git a/examples/dem_sd.py b/examples/dem_sd.py index c0e0212..15ee699 100644 --- a/examples/dem_sd.py +++ b/examples/dem_sd.py @@ -165,7 +165,7 @@ psim.setup(update_mass_and_inertia, {'densityParticle_SI': densityParticle_SI, #psim.compute_half() psim.build_cell_lists(linkedCellWidth) -# psim.vtk_output(f"output/dem_{target}", frequency=visSpacing) +psim.vtk_output(f"output/dem_{target}", frequency=visSpacing) psim.compute(gravity, symbols={'densityParticle_SI': densityParticle_SI, @@ -175,11 +175,10 @@ psim.compute(gravity, psim.compute(linear_spring_dashpot, linkedCellWidth, - symbols={'dt': dt_SI, - 'pi': math.pi, + symbols={'pi': math.pi, 'kappa': kappa, 'lnDryResCoeff': lnDryResCoeff, 'collisionTime_SI': collisionTime_SI}) -psim.compute(euler, symbols={'dt': dt_SI}) +psim.compute(euler, parameters={'dt' : pairs.real()}) psim.generate() diff --git a/examples/main.cpp b/examples/main.cpp index ea9094f..0bae494 100644 --- a/examples/main.cpp +++ b/examples/main.cpp @@ -5,28 +5,29 @@ #include <blockforest/Initialization.h> int main(int argc, char **argv) { + double dt = 5e-5; auto pairs_sim = std::make_shared<PairsSimulation>(); auto pairs_acc = std::make_shared<PairsAccessor>(pairs_sim); // Create forest (make sure to use_domain(forest)) ---------------------------------------------- - walberla::math::AABB domain(0, 0, 0, 0.1, 0.1, 0.1); - std::shared_ptr<walberla::mpi::MPIManager> mpiManager = walberla::mpi::MPIManager::instance(); - mpiManager->initializeMPI(&argc, &argv); - mpiManager->useWorldComm(); - auto rank = mpiManager->rank(); - auto procs = mpiManager->numProcesses(); - auto block_config = walberla::Vector3<int>(2, 2, 2); - auto ref_level = 0; - std::shared_ptr<walberla::BlockForest> forest = walberla::blockforest::createBlockForest( - domain, block_config, walberla::Vector3<bool>(false, false, false), procs, ref_level); + // walberla::math::AABB domain(0, 0, 0, 0.1, 0.1, 0.1); + // std::shared_ptr<walberla::mpi::MPIManager> mpiManager = walberla::mpi::MPIManager::instance(); + // mpiManager->initializeMPI(&argc, &argv); + // mpiManager->useWorldComm(); + // auto procs = mpiManager->numProcesses(); + // auto block_config = walberla::Vector3<int>(2, 2, 2); + // auto ref_level = 0; + // std::shared_ptr<walberla::BlockForest> forest = walberla::blockforest::createBlockForest( + // domain, block_config, walberla::Vector3<bool>(false, false, false), procs, ref_level); //----------------------------------------------------------------------------------------------- // initialize pairs data structures ---------------------------------------------- pairs_sim->initialize(); // either create new domain or use an existing one ---------------------------------------- - // pairs_sim->create_domain(argc, argv); - pairs_sim->use_domain(forest); + pairs_sim->create_domain(argc, argv); + + // pairs_sim->use_domain(forest); // create planes and particles ------------------------------------------------------------ pairs_sim->create_halfspace(0, 0, 0, 1, 0, 0, 0, 13); @@ -36,23 +37,26 @@ int main(int argc, char **argv) { pairs_sim->create_halfspace(0.1, 0.1, 0.1, 0, -1, 0, 0, 13); pairs_sim->create_halfspace(0.1, 0.1, 0.1, 0, 0, -1, 0, 13); - pairs::id_t pUid = pairs_sim->create_sphere(0.0499, 0.0499, 0.07, 0.5, 0.5, 0 , 1000, 0.0045, 0, 0); - pairs::id_t pUid2 = pairs_sim->create_sphere(0.0499, 0.0499, 0.0499, 0.5, 0.5, 0 , 1000, 0.0045, 0, 0); + // pairs::id_t pUid = pairs_sim->create_sphere(0.0499, 0.0499, 0.07, 0.5, 0.5, 0 , 1000, 0.0045, 0, 0); + // pairs::id_t pUid2 = pairs_sim->create_sphere(0.0499, 0.0499, 0.0499, 0.5, 0.5, 0 , 1000, 0.0045, 0, 0); // Tracking a particle ------------------------------------------------------------------------ // if (pUid != pairs_acc->getInvalidUid()){ // std::cout<< "Particle " << pUid << " is created in rank " << rank << std::endl; // } - walberla::mpi::allReduceInplace(pUid, walberla::mpi::SUM); - walberla::mpi::allReduceInplace(pUid2, walberla::mpi::SUM); + // walberla::mpi::allReduceInplace(pUid, walberla::mpi::SUM); + // walberla::mpi::allReduceInplace(pUid2, walberla::mpi::SUM); + // MPI_Allreduce(MPI_IN_PLACE, &pUid, 1, MPI_LONG_LONG_INT, MPI_SUM, MPI_COMM_WORLD); + // MPI_Allreduce(MPI_IN_PLACE, &pUid2, 1, MPI_LONG_LONG_INT, MPI_SUM, MPI_COMM_WORLD); + // if (pUid != pairs_acc->getInvalidUid()){ // std::cout<< "Particle " << pUid << " will be tracked by rank " << rank << std::endl; // } - auto pIsLocalInMyRank = [&](pairs::id_t uid){return pairs_acc->uidToIdx(uid) != pairs_acc->getInvalidIdx();}; - auto pIsGhostInMyRank = [&](pairs::id_t uid){return pairs_acc->uidToIdxGhost(uid) != pairs_acc->getInvalidIdx();}; + // auto pIsLocalInMyRank = [&](pairs::id_t uid){return pairs_acc->uidToIdx(uid) != pairs_acc->getInvalidIdx();}; + // auto pIsGhostInMyRank = [&](pairs::id_t uid){return pairs_acc->uidToIdxGhost(uid) != pairs_acc->getInvalidIdx();}; // TODO: make sure linkedCellWidth is larger than max diameter in the system @@ -74,55 +78,95 @@ int main(int argc, char **argv) { // pairs_acc->setLinearVelocity(idx, walberla::Vector3<double>(0.5, 0.5, 0.5)); // } - for (int t=0; t<1; ++t){ + for (int t=0; t<5000; ++t){ // if ((t%200==0) && (pIsLocalInMyRank(pUid))){ // int idx = pairs_acc->uidToIdx(pUid); // std::cout<< "Tracked particle is now in rank " << rank << " --- " << pairs_acc->getPosition(idx)<< std::endl; // } pairs_sim->communicate(t); + + // if(pairs_sim->rank() == 0){ + // pairs_sim->pairs_runtime->copyPropertyToHost(0, ReadOnly, (((pairs_sim->pobj->nlocal + pairs_sim->pobj->nghost) * 1) * sizeof(unsigned long long int))); // uid + // for(int i=0; i<pairs_sim->pobj->nlocal; ++i){ + // std::cout << "rank ##0## local uid[" << i << "] = " << pairs_sim->pobj->uid[i] << std::endl; + // } + // } + // if(pairs_sim->rank() == 2){ + // pairs_sim->pairs_runtime->copyPropertyToHost(0, ReadOnly, (((pairs_sim->pobj->nlocal + pairs_sim->pobj->nghost) * 1) * sizeof(unsigned long long int))); // uid + // for(int i=0; i<pairs_sim->pobj->nlocal; ++i){ + // std::cout << "rank 2 local uid[" << i << "] = " << pairs_sim->pobj->uid[i] << std::endl; + // } + // for(int i=pairs_sim->pobj->nlocal; i<pairs_sim->pobj->nlocal + pairs_sim->pobj->nghost; ++i){ + // std::cout << "rank 2 ghost uid[" << i << "] = " << pairs_sim->pobj->uid[i] << std::endl; + // } + // } + + // std::cout << pairs_sim->rank() << " --------- nlocal = " << pairs_acc->nlocal() << " nghost = " << pairs_acc->nghost() << std::endl; + + // if (pairs_sim->rank() == 2) { + // for(int i=0; i<pairs_acc->nlocal(); ++i){ + // std::cout << pairs_sim->rank() << " Local " << i << " -- uid= " << pairs_acc->getUid(i) << std::endl; + // } + // for(int i=pairs_acc->nlocal(); i<pairs_acc->size(); ++i){ + // std::cout << pairs_sim->rank() << " Ghost -- " << i << " -- uid= " << pairs_acc->getUid(i) << std::endl; + // } + // } + + + + // if (pIsLocalInMyRank(pUid)){ + // int idx = pairs_acc->uidToIdx(pUid); + // pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(1, 1, 1)); + // pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(2, 2, 2)); + // } + // if (pIsLocalInMyRank(pUid2)){ + // int idx = pairs_acc->uidToIdx(pUid2); + // pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(10, 10, 10)); + // pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(20, 20, 20)); + // } - if (pIsLocalInMyRank(pUid)){ - int idx = pairs_acc->uidToIdx(pUid); - pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(1, 1, 1)); - pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(2, 2, 2)); - } - if (pIsLocalInMyRank(pUid2)){ - int idx = pairs_acc->uidToIdx(pUid2); - pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(10, 10, 10)); - pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(20, 20, 20)); - } - - pairs_sim->do_timestep(t); - - if (pIsGhostInMyRank(pUid)){ - int idx = pairs_acc->uidToIdxGhost(pUid); - pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(1, 1, 1)); - pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(3, 3, 3)); - } - if (pIsGhostInMyRank(pUid2)){ - int idx = pairs_acc->uidToIdxGhost(pUid2); - pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(2, 2, 2)); - pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(4, 4, 4)); - } - - std::cout << "reverse comm and reduce" << std::endl; - pairs_sim->reverse_comm(); - - if (pIsLocalInMyRank(pUid)){ - int idx = pairs_acc->uidToIdx(pUid); - auto forceSum = pairs_acc->getHydrodynamicForce(idx); - auto torqueSum = pairs_acc->getHydrodynamicTorque(idx); - std::cout << pUid << " @@@@ reduced force = " << forceSum << std::endl; - std::cout << pUid << " @@@@ reduced torque = " << torqueSum << std::endl; - } - if (pIsLocalInMyRank(pUid2)){ - int idx = pairs_acc->uidToIdx(pUid2); - auto forceSum = pairs_acc->getHydrodynamicForce(idx); - auto torqueSum = pairs_acc->getHydrodynamicTorque(idx); - std::cout << pUid2 << " @@@@ reduced force = " << forceSum << std::endl; - std::cout << pUid2 << " @@@@ reduced torque = " << torqueSum << std::endl; - } + pairs_sim->update_cells(); + pairs_sim->gravity(); + pairs_sim->linear_spring_dashpot(); + pairs_sim->euler(dt); + + // std::cout << "reset_volatiles" << std::endl; + pairs_sim->reset_volatiles(); + + pairs_sim->vtk_write("output/dem_sd_local", 0, pairs_acc->nlocal(), t, 200); + pairs_sim->vtk_write("output/dem_sd_ghost", pairs_acc->nlocal(), pairs_acc->size(), t, 200); + + // if (pIsGhostInMyRank(pUid)){ + // std::cout << pairs_sim->rank() << " - ghost 1: " << pUid << std::endl; + // int idx = pairs_acc->uidToIdxGhost(pUid); + // pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(1, 1, 1)); + // pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(3, 3, 3)); + // } + // if (pIsGhostInMyRank(pUid2)){ + // std::cout << pairs_sim->rank() << " - ghost 2: " << pUid2 << std::endl; + // int idx = pairs_acc->uidToIdxGhost(pUid2); + // pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(2, 2, 2)); + // pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(4, 4, 4)); + // } + + // std::cout << "reverse comm and reduce" << std::endl; + // pairs_sim->reverse_comm(); + + // if (pIsLocalInMyRank(pUid)){ + // int idx = pairs_acc->uidToIdx(pUid); + // auto forceSum = pairs_acc->getHydrodynamicForce(idx); + // auto torqueSum = pairs_acc->getHydrodynamicTorque(idx); + // std::cout << pUid << " @@@@ reduced force = " << forceSum << std::endl; + // std::cout << pUid << " @@@@ reduced torque = " << torqueSum << std::endl; + // } + // if (pIsLocalInMyRank(pUid2)){ + // int idx = pairs_acc->uidToIdx(pUid2); + // auto forceSum = pairs_acc->getHydrodynamicForce(idx); + // auto torqueSum = pairs_acc->getHydrodynamicTorque(idx); + // std::cout << pUid2 << " @@@@ reduced force = " << forceSum << std::endl; + // std::cout << pUid2 << " @@@@ reduced torque = " << torqueSum << std::endl; + // } } pairs_sim->end(); diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py index 23a3bed..1c009af 100644 --- a/src/pairs/analysis/devices.py +++ b/src/pairs/analysis/devices.py @@ -196,7 +196,11 @@ class FetchKernelReferences(Visitor): # Variables only have a device version when changed within kernels if self.writing: ast_node.device_flag = True - + + def visit_Parameter(self, ast_node): + for k in self.kernel_stack: + k.add_parameter(ast_node, self.writing) + def visit_Iter(self, ast_node): for k in self.kernel_stack: if ast_node.is_ref_candidate(): diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py index 4cf0bf8..24ab2be 100644 --- a/src/pairs/analysis/modules.py +++ b/src/pairs/analysis/modules.py @@ -115,3 +115,8 @@ class FetchModulesReferences(Visitor): for m in self.module_stack: if not ast_node.temporary(): m.add_variable(ast_node, self.writing) + + def visit_Parameter(self, ast_node): + for m in self.module_stack: + # parameters are restricted to read-only, passed by value + m.add_parameter(ast_node, write=False) \ No newline at end of file diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index c9e2c2e..27093c5 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -28,6 +28,7 @@ from pairs.ir.sizeof import Sizeof from pairs.ir.types import Types from pairs.ir.print import Print, PrintCode from pairs.ir.variables import Var, DeclareVariable, Deref +from pairs.ir.parameters import Parameter from pairs.ir.vectors import Vector, VectorAccess, VectorOp, ZeroVector from pairs.sim.domain_partitioners import DomainPartitioners from pairs.sim.timestep import Timestep @@ -175,10 +176,19 @@ class CGen: self.print("") self.print("// Module headers") - for module in self.sim.modules(): - if module.name != "main": - self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj);") + self.print("namespace pairs::internal {") + self.print.add_indent(4) + + for module in self.sim.modules(): + if module.name != "main" and not module.user_defined: + module_params = ", ".join(f"{Types.c_keyword(self.sim, param.type())} {param.name()}" + for param in module.parameters()) + module_params = ", " + module_params if module_params else "" + self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj{module_params});") + + self.print.add_indent(-4) + self.print("}") self.print("") def generate_host_pairs_accessor_class(self): @@ -370,7 +380,7 @@ class CGen: self.print.end() - def generate_library(self, initialize_module, create_domain_module, setup_sim_module, do_timestep_module, reverse_comm_module, communicate_module, reset_volatiles_module): + def generate_library(self, update_cells_module, user_defined_modules, initialize_module, create_domain_module, setup_sim_module, do_timestep_module, reverse_comm_module, communicate_module, reset_volatiles_module): self.generate_interfaces() # Generate CUDA/CPP file with modules ext = ".cu" if self.target.is_gpu() else ".cpp" @@ -378,6 +388,7 @@ class CGen: self.print.start() self.generate_preamble() self.print(f"#include \"{self.ref}.hpp\"") + self.print("") if self.target.is_gpu(): for array in self.sim.arrays.statics(): @@ -393,13 +404,22 @@ class CGen: tkw = Types.c_keyword(self.sim, t) size = feature_prop.array_size() self.print(f"__constant__ {tkw} d_{feature_prop.name()}[{size}];") + + self.print("") + self.print("namespace pairs::internal {") + self.print.add_indent(4) + for kernel in self.sim.kernels(): self.generate_kernel(kernel) for module in self.sim.modules(): - if module.name not in ['initialize', 'create_domain', 'setup_sim', 'do_timestep', 'reverse_comm', 'communicate', 'reset_volatiles']: - self.generate_module(module) + if module.name not in ['update_cells', 'initialize', 'create_domain', 'setup_sim', 'do_timestep', 'reverse_comm', 'communicate', 'reset_volatiles']: + if not module.user_defined: + self.generate_module(module) + + self.print.add_indent(-4) + self.print("}") self.print.end() @@ -420,7 +440,7 @@ class CGen: self.generate_full_object_names = True self.print("class PairsSimulation {") - self.print("private:") + self.print("public:") self.print(" PairsRuntime *pairs_runtime;") self.print(" struct PairsObjects *pobj;") self.print(" friend class PairsAccessor;") @@ -428,19 +448,19 @@ class CGen: self.print.add_indent(4) + for module in user_defined_modules: + self.generate_module(module) + self.print("") + self.print("void initialize() {") self.print(f" pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});") self.print(f" pobj = new PairsObjects();") - self.print.add_indent(4) self.generate_statement(initialize_module.block) - self.print.add_indent(-4) self.print("}") self.print("") self.print("void create_domain(int argc, char **argv) {") - self.print.add_indent(4) self.generate_statement(create_domain_module.block) - self.print.add_indent(-4) self.print("}") self.print("") @@ -466,39 +486,43 @@ class CGen: self.print("") self.print("void setup_sim() {") - self.print.add_indent(4) self.generate_statement(setup_sim_module.block) - self.print.add_indent(-4) self.print("}") self.print("") - self.print("void do_timestep(int timestep) {") - self.print(" pobj->sim_timestep = timestep;") - self.print.add_indent(4) - self.generate_statement(do_timestep_module.block) - self.print.add_indent(-4) + self.print("void update_cells() {") + self.generate_statement(update_cells_module.block) self.print("}") self.print("") + # self.print("void do_timestep(int timestep) {") + # self.print(" pobj->sim_timestep = timestep;") + # self.print.add_indent(4) + # self.generate_statement(do_timestep_module.block) + # self.print.add_indent(-4) + # self.print("}") + # self.print("") + + + self.print("void reverse_comm() {") - self.print.add_indent(4) self.generate_statement(reverse_comm_module.block) - self.print.add_indent(-4) self.print("}") self.print("") self.print("void communicate(int timestep) {") self.print(" pobj->sim_timestep = timestep;") - self.print.add_indent(4) self.generate_statement(communicate_module.block) - self.print.add_indent(-4) self.print("}") self.print("") self.print("void reset_volatiles() {") - self.print.add_indent(4) self.generate_statement(reset_volatiles_module.block) - self.print.add_indent(-4) + self.print("}") + self.print("") + + self.print("void vtk_write(const char* filename, int start, int end, int timestep, int frequency) {") + self.print(" pairs::vtk_write_data(pairs_runtime, filename, start, end, timestep, frequency);") self.print("}") self.print("") @@ -519,6 +543,62 @@ class CGen: self.print.end() self.generate_full_object_names = False + def generate_module_declerations(self, module): + device_cond = module.run_on_device and self.target.is_gpu() + + for var in module.read_only_variables(): + type_kw = Types.c_keyword(self.sim, var.type()) + self.print(f"{type_kw} {var.name()} = pobj->{var.name()};") + + for var in module.write_variables(): + type_kw = Types.c_keyword(self.sim, var.type()) + + if device_cond and var.device_flag: + self.print(f"{type_kw} *{var.name()} = pobj->rv_{var.name()}.getDevicePointer();") + elif var.force_read: + self.print(f"{type_kw} {var.name()} = pobj->{var.name()};") + else: + self.print(f"{type_kw} *{var.name()} = &(pobj->{var.name()});") + + for array in module.arrays(): + type_kw = Types.c_keyword(self.sim, array.type()) + name = array.name() if not device_cond else f"d_{array.name()}" + if not array.is_static() or (array.is_static() and not device_cond): + self.print(f"{type_kw} *{array.name()} = pobj->{name};") + + if array in module.host_references(): + self.print(f"{type_kw} *h_{array.name()} = pobj->{array.name()};") + + + for prop in module.properties(): + type_kw = Types.c_keyword(self.sim, prop.type()) + name = prop.name() if not device_cond else f"d_{prop.name()}" + self.print(f"{type_kw} *{prop.name()} = pobj->{name};") + + if prop in module.host_references(): + self.print(f"{type_kw} *h_{prop.name()} = pobj->{prop.name()};") + + for contact_prop in module.contact_properties(): + type_kw = Types.c_keyword(self.sim, contact_prop.type()) + name = contact_prop.name() if not device_cond else f"d_{contact_prop.name()}" + self.print(f"{type_kw} *{contact_prop.name()} = pobj->{name};") + + if contact_prop in module.host_references(): + self.print(f"{type_kw} *h_{contact_prop.name()} = pobj->{contact_prop.name()};") + + for feature_prop in module.feature_properties(): + type_kw = Types.c_keyword(self.sim, feature_prop.type()) + name = feature_prop.name() if not device_cond else f"d_{feature_prop.name()}" + + if feature_prop.device_flag and device_cond: + # self.print(f"{type_kw} *{feature_prop.name()} = {self.generate_object_reference(feature_prop, device=device_cond)};") + continue + else: + self.print(f"{type_kw} *{feature_prop.name()} = pobj->{name};") + + if feature_prop in module.host_references(): + self.print(f"{type_kw} *h_{feature_prop.name()} = pobj->{feature_prop.name()};") + def generate_module(self, module): if module.name == 'main': ndims = module.sim.ndims() @@ -549,65 +629,23 @@ class CGen: self.generate_full_object_names = False else: - self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj) {{") + module_params = ", ".join(f"{Types.c_keyword(self.sim, param.type())} {param.name()}" + for param in module.parameters()) + if not module.user_defined: + module_params = ", " + module_params if module_params else "" + self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj{module_params}) {{") + else: + + self.print(f"void {module.name}({module_params}) {{") + + self.print.add_indent(4) - device_cond = module.run_on_device and self.target.is_gpu() if self.debug: self.print(f"PAIRS_DEBUG(\"\\n{module.name}\\n\");") - for var in module.read_only_variables(): - type_kw = Types.c_keyword(self.sim, var.type()) - self.print(f"{type_kw} {var.name()} = pobj->{var.name()};") - - for var in module.write_variables(): - type_kw = Types.c_keyword(self.sim, var.type()) - - if device_cond and var.device_flag: - self.print(f"{type_kw} *{var.name()} = pobj->rv_{var.name()}.getDevicePointer();") - elif var.force_read: - self.print(f"{type_kw} {var.name()} = pobj->{var.name()};") - else: - self.print(f"{type_kw} *{var.name()} = &(pobj->{var.name()});") - - for array in module.arrays(): - type_kw = Types.c_keyword(self.sim, array.type()) - name = array.name() if not device_cond else f"d_{array.name()}" - if not array.is_static() or (array.is_static() and not device_cond): - self.print(f"{type_kw} *{array.name()} = pobj->{name};") - - if array in module.host_references(): - self.print(f"{type_kw} *h_{array.name()} = pobj->{array.name()};") - - - for prop in module.properties(): - type_kw = Types.c_keyword(self.sim, prop.type()) - name = prop.name() if not device_cond else f"d_{prop.name()}" - self.print(f"{type_kw} *{prop.name()} = pobj->{name};") - - if prop in module.host_references(): - self.print(f"{type_kw} *h_{prop.name()} = pobj->{prop.name()};") - - for contact_prop in module.contact_properties(): - type_kw = Types.c_keyword(self.sim, contact_prop.type()) - name = contact_prop.name() if not device_cond else f"d_{contact_prop.name()}" - self.print(f"{type_kw} *{contact_prop.name()} = pobj->{name};") - - if contact_prop in module.host_references(): - self.print(f"{type_kw} *h_{contact_prop.name()} = pobj->{contact_prop.name()};") - - for feature_prop in module.feature_properties(): - type_kw = Types.c_keyword(self.sim, feature_prop.type()) - name = feature_prop.name() if not device_cond else f"d_{feature_prop.name()}" - - if feature_prop.device_flag and device_cond: - # self.print(f"{type_kw} *{feature_prop.name()} = {self.generate_object_reference(feature_prop, device=device_cond)};") - continue - else: - self.print(f"{type_kw} *{feature_prop.name()} = pobj->{name};") - - if feature_prop in module.host_references(): - self.print(f"{type_kw} *h_{feature_prop.name()} = pobj->{feature_prop.name()};") + if not module.user_defined: + self.generate_module_declerations(module) self.print.add_indent(-4) self.generate_statement(module.block) @@ -616,6 +654,11 @@ class CGen: def generate_kernel(self, kernel): kernel_params = "int range_start" has_resizes = False + for param in kernel.parameters(): + type_kw = Types.c_keyword(self.sim, param.type()) + decl = f"{type_kw} {param.name()}" + kernel_params += f", {decl}" + for var in kernel.read_only_variables(): type_kw = Types.c_keyword(self.sim, var.type()) decl = f"{type_kw} {var.name()}" @@ -981,6 +1024,9 @@ class CGen: kernel = ast_node.kernel kernel_params = f"{range_start}" + for param in kernel.parameters(): + kernel_params += f", {param.name()}" + for var in kernel.read_only_variables(): kernel_params += f", {var.name()}" @@ -1022,7 +1068,9 @@ class CGen: self.print("}") if isinstance(ast_node, ModuleCall): - self.print(f"{ast_node.module.name}(pairs_runtime, pobj);") + module_params = ", ".join(f"{param.name()}" for param in ast_node.module.parameters()) + module_params = ", " + module_params if module_params else "" + self.print(f"pairs::internal::{ast_node.module.name}(pairs_runtime, pobj{module_params});") if isinstance(ast_node, Print): args = ast_node.args @@ -1333,7 +1381,10 @@ class CGen: if isinstance(ast_node, Var): return self.generate_object_reference(ast_node, index=index) - + + if isinstance(ast_node, Parameter): + return ast_node.name() + if isinstance(ast_node, VectorAccess): return self.generate_expression(ast_node.expr, mem, self.generate_expression(ast_node.index)) diff --git a/src/pairs/ir/declaration.py b/src/pairs/ir/declaration.py index 3e26c57..35992d5 100644 --- a/src/pairs/ir/declaration.py +++ b/src/pairs/ir/declaration.py @@ -7,7 +7,7 @@ class Decl(ASTNode): self.elem = elem def __str__(self): - return f"Decl<self.elem>" + return f"Decl<{self.elem}>" def children(self): return [self.elem] diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py index d7cb0d1..4eddecf 100644 --- a/src/pairs/ir/device.py +++ b/src/pairs/ir/device.py @@ -62,6 +62,9 @@ class CopyProperty(ASTNode): self._action = action self.sim.add_statement(self) + def __str__(self): + return f"CopyProperty<{self._prop}>" + def prop(self): return self._prop diff --git a/src/pairs/ir/functions.py b/src/pairs/ir/functions.py index b18406b..efaf3e6 100644 --- a/src/pairs/ir/functions.py +++ b/src/pairs/ir/functions.py @@ -11,6 +11,9 @@ class Call(ASTTerm): self.params = [Lit.cvt(sim, p) for p in params] self.return_type = return_type + def __str__(self): + return f"Call<{self.func_name}, {self.params}>" + def name(self): return self.func_name @@ -28,8 +31,13 @@ class Call_Int(Call): def __init__(self, sim, func_name, parameters): super().__init__(sim, func_name, parameters, Types.Int32) + def __str__(self): + return f"Call_Int<{self.func_name}, {self.params}>" class Call_Void(Call): def __init__(self, sim, func_name, parameters): super().__init__(sim, func_name, parameters, Types.Invalid) sim.add_statement(self) + + def __str__(self): + return f"Cal_Void<{self.func_name}, {self.params}>" diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py index 5faaee4..4f477ca 100644 --- a/src/pairs/ir/kernel.py +++ b/src/pairs/ir/kernel.py @@ -8,6 +8,7 @@ from pairs.ir.matrices import MatrixOp from pairs.ir.properties import Property, ContactProperty from pairs.ir.quaternions import QuaternionOp from pairs.ir.variables import Var +from pairs.ir.parameters import Parameter from pairs.ir.vectors import VectorOp from pairs.ir.loops import Iter @@ -20,6 +21,7 @@ class Kernel(ASTNode): self._id = Kernel.last_kernel self._name = name if name is not None else "kernel" + str(Kernel.last_kernel) self._variables = {} + self._parameters = {} self._iters = {} self._arrays = {} self._properties = {} @@ -52,6 +54,9 @@ class Kernel(ASTNode): def variables(self): return self._variables + def parameters(self): + return self._parameters + def iters(self): return self._iters @@ -105,6 +110,17 @@ class Kernel(ASTNode): action = Actions.NoAction if var not in self._variables else self._variables[var] self._variables[var] = Actions.update_rule(action, new_op) + def add_parameter(self, parameter, write=False): + parameter_list = parameter if isinstance(parameter, list) else [parameter] + new_op = 'w' if write else 'r' + + for param in parameter_list: + assert isinstance(param, Parameter), \ + "Module.add_parameter(): given element is not of type Parameter!" + + action = Actions.NoAction if param not in self._parameters else self._parameters[param] + self._parameters[param] = Actions.update_rule(action, new_op) + def add_iter(self, iter, write=False): iter_list = iter if isinstance(iter, list) else [iter] new_op = 'w' if write else 'r' diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index ab78942..52ee117 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -4,15 +4,17 @@ from pairs.ir.ast_node import ASTNode from pairs.ir.features import FeatureProperty from pairs.ir.properties import Property, ContactProperty from pairs.ir.variables import Var +from pairs.ir.parameters import Parameter class Module(ASTNode): last_module = 0 - def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False): + def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False, user_defined=False): super().__init__(sim) self._id = Module.last_module self._name = name if name is not None else "module" + str(Module.last_module) + self._parameters = {} self._variables = {} self._arrays = {} self._properties = {} @@ -23,6 +25,7 @@ class Module(ASTNode): self._resizes_to_check = resizes_to_check self._check_properties_resize = check_properties_resize self._run_on_device = run_on_device + self._user_defined = user_defined self._profile = False sim.add_module(self) Module.last_module += 1 @@ -45,6 +48,10 @@ class Module(ASTNode): @property def run_on_device(self): return self._run_on_device + + @property + def user_defined(self): + return self._user_defined def profile(self): self._profile = True @@ -53,6 +60,9 @@ class Module(ASTNode): def must_profile(self): return self._profile + def parameters(self): + return self._parameters + def variables(self): return self._variables @@ -99,6 +109,17 @@ class Module(ASTNode): action = Actions.NoAction if var not in self._variables else self._variables[var] self._variables[var] = Actions.update_rule(action, new_op) + def add_parameter(self, parameter, write=False): + parameter_list = parameter if isinstance(parameter, list) else [parameter] + new_op = 'w' if write else 'r' + + for param in parameter_list: + assert isinstance(param, Parameter), \ + "Module.add_parameter(): given element is not of type Parameter!" + + action = Actions.NoAction if param not in self._parameters else self._parameters[param] + self._parameters[param] = Actions.update_rule(action, new_op) + def add_property(self, prop, write=False): prop_list = prop if isinstance(prop, list) else [prop] new_op = 'w' if write else 'r' @@ -150,5 +171,8 @@ class ModuleCall(ASTNode): def module(self): return self._module + def __str__(self): + return f"ModuleCall<{self._module}>" + def children(self): return [self._module] diff --git a/src/pairs/ir/parameters.py b/src/pairs/ir/parameters.py new file mode 100644 index 0000000..d6b9ab6 --- /dev/null +++ b/src/pairs/ir/parameters.py @@ -0,0 +1,18 @@ +from pairs.ir.ast_term import ASTTerm +from pairs.ir.operator_class import OperatorClass + + +class Parameter(ASTTerm): + def __init__(self, sim, param_name, param_type): + super().__init__(sim, OperatorClass.from_type(param_type)) + self.param_name = param_name + self.param_type = param_type + + def __str__(self): + return f"Parameter<{self.param_name}>" + + def name(self): + return self.param_name + + def type(self): + return self.param_type diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py index 0632bd2..efcb825 100644 --- a/src/pairs/mapping/funcs.py +++ b/src/pairs/mapping/funcs.py @@ -7,6 +7,7 @@ from pairs.ir.loops import For, ParticleFor from pairs.ir.operators import Operators from pairs.ir.operator_class import OperatorClass from pairs.ir.properties import ContactProperty +from pairs.ir.parameters import Parameter from pairs.ir.scalars import ScalarOp from pairs.ir.types import Types from pairs.mapping.keywords import Keywords @@ -80,9 +81,10 @@ class BuildParticleIR(ast.NodeVisitor): raise Exception("Invalid operator: {}".format(ast.dump(op))) - def __init__(self, sim, ctx_symbols={}): + def __init__(self, sim, ctx_symbols={}, func_params={}): self.sim = sim self.ctx_symbols = ctx_symbols.copy() + self.func_params = func_params.copy() self.keywords = Keywords(sim) def add_symbols(self, symbols): @@ -210,6 +212,7 @@ class BuildParticleIR(ast.NodeVisitor): def visit_Name(self, node): symbol_types = [ self.ctx_symbols.get, + self.func_params.get, self.sim.array, self.sim.property, self.sim.feature_property, @@ -282,7 +285,7 @@ class BuildParticleIR(ast.NodeVisitor): return op_class(self.sim, operand, None, BuildParticleIR.get_unary_op(node.op)) -def compute(sim, func, cutoff_radius=None, symbols={}, pre_step=False, skip_first=False): +def compute(sim, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=False, skip_first=False): src = inspect.getsource(func) tree = ast.parse(src, mode='exec') #print(ast.dump(ast.parse(src, mode='exec'))) @@ -298,6 +301,7 @@ def compute(sim, func, cutoff_radius=None, symbols={}, pre_step=False, skip_firs # Convert literal symbols symbols = {symbol: Lit.cvt(sim, value) for symbol, value in symbols.items()} + parameters = {pname: Parameter(sim, pname, ptype) for pname, ptype in parameters.items()} sim.init_block() sim.module_name(func.__name__) @@ -305,14 +309,14 @@ def compute(sim, func, cutoff_radius=None, symbols={}, pre_step=False, skip_firs if nparams == 1: for i in ParticleFor(sim): for _ in Filter(sim, ScalarOp.cmp(sim.particle_flags[i] & Flags.Fixed, 0)): - ir = BuildParticleIR(sim, symbols) + ir = BuildParticleIR(sim, symbols, parameters) ir.add_symbols({params[0]: i}) ir.visit(tree) else: for interaction_data in ParticleInteraction(sim, nparams, cutoff_radius): # Start building IR - ir = BuildParticleIR(sim, symbols) + ir = BuildParticleIR(sim, symbols, parameters) ir.add_symbols({ params[0]: interaction_data.i(), params[1]: interaction_data.j(), diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 044a135..5c665e4 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -93,11 +93,13 @@ class Simulation: # Different segments of particle code/functions self.create_domain = Block(self, []) self.setup_particles = Block(self, []) + self.module_list = [] + self.kernel_list = [] + + # User-defined modules self.setup_functions = [] self.pre_step_functions = [] self.functions = [] - self.module_list = [] - self.kernel_list = [] # Structures to generated resize code for capacities self._check_properties_resize = False @@ -303,8 +305,8 @@ class Simulation: self.neighbor_lists = NeighborLists(self, self.cell_lists) return self.neighbor_lists - def compute(self, func, cutoff_radius=None, symbols={}, pre_step=False, skip_first=False): - return compute(self, func, cutoff_radius, symbols, pre_step, skip_first) + def compute(self, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=False, skip_first=False): + return compute(self, func, cutoff_radius, symbols, parameters, pre_step, skip_first) def setup(self, func, symbols={}): return setup(self, func, symbols) @@ -447,8 +449,10 @@ class Simulation: # Params that determine when a method must be called only when reneighboring every_reneighbor_params = {'every': self.reneighbor_frequency} + timestep_procedures = [] + # First steps executed during each time-step in the simulation - timestep_procedures = self.pre_step_functions + timestep_procedures += self.pre_step_functions comm_routine = [ (comm.exchange(), every_reneighbor_params), @@ -458,21 +462,23 @@ class Simulation: if self._generate_whole_program: timestep_procedures += comm_routine - timestep_procedures += [ + update_cells = [ (BuildCellLists(self, self.cell_lists), every_reneighbor_params), (PartitionCellLists(self, self.cell_lists), every_reneighbor_params) ] # Add routine to build neighbor-lists per cell if self._store_neighbors_per_cell: - timestep_procedures.append( + update_cells.append( (BuildCellNeighborLists(self, self.cell_lists), every_reneighbor_params)) # Add routine to build neighbor-lists per particle (standard Verlet Lists) if self.neighbor_lists is not None: - timestep_procedures.append( + update_cells.append( (BuildNeighborLists(self, self.neighbor_lists), every_reneighbor_params)) + timestep_procedures += update_cells + # Add routines for contact history management if self._use_contact_history: if self.neighbor_lists is not None: @@ -548,6 +554,9 @@ class Simulation: # Generate a small library to be called else: + update_cells = [m[0] if isinstance(m, tuple) else m for m in update_cells] + update_cells_module = Module(self, name='update_cells', block=Block.from_list(self, update_cells)) + initialize_module = Module(self, name='initialize', block=inits) create_domain_module = Module(self, name='create_domain', block=self.create_domain) @@ -563,12 +572,34 @@ class Simulation: communicate_module = Module(self, name='communicate', block=Timestep(self, 0, comm_routine).as_block()) reset_volatiles_module = Module(self, name='reset_volatiles', block=Block(self, ResetVolatileProperties(self))) - modules_list = [initialize_module, create_domain_module, setup_sim_module, do_timestep_module, reverse_comm_module, communicate_module, reset_volatiles_module] + modules_list = [ + update_cells_module, + initialize_module, + create_domain_module, + setup_sim_module, + reverse_comm_module, + communicate_module, + reset_volatiles_module + ] - transformations = Transformations(modules_list, self._target) - transformations.apply_all() + Transformations(modules_list, self._target).apply_all() + + # user defined modules are transformed seperately as indvidual modules + # i.e. they are transformed once again if already transformed in setup_sim or do_timestep + user_defined_modules = self.setup_functions + self.pre_step_functions + self.functions + user_defined_modules = [m[0] if isinstance(m, tuple) else m for m in user_defined_modules] + user_defined_modules = [Module(self, name=m.name, block=Block(self, m), user_defined=True) for m in user_defined_modules] + Transformations(user_defined_modules, self._target).apply_all() # Generate library - self.code_gen.generate_library(initialize_module, create_domain_module, setup_sim_module, do_timestep_module, reverse_comm_module, communicate_module, reset_volatiles_module) + self.code_gen.generate_library(update_cells_module, + user_defined_modules, + initialize_module, + create_domain_module, + setup_sim_module, + do_timestep_module, + reverse_comm_module, + communicate_module, + reset_volatiles_module) self.code_gen.generate_interfaces() -- GitLab