diff --git a/examples/modular/sd_1_CPU_GPU.cpp b/examples/modular/sd_1_CPU_GPU.cpp new file mode 100644 index 0000000000000000000000000000000000000000..898d3e9826493af37b4c29da9a382322e219173e --- /dev/null +++ b/examples/modular/sd_1_CPU_GPU.cpp @@ -0,0 +1,47 @@ +#include <iostream> +#include <memory> + +#include "spring_dashpot.hpp" + +int main(int argc, char **argv) { + + auto pairs_sim = std::make_shared<PairsSimulation>(); + pairs_sim->initialize(); + + pairs_sim->set_domain(argc, argv, 0, 0, 0, 1, 1, 1); + + pairs_sim->create_halfspace(0,0,0, 1, 0, 0, 0, 13); + pairs_sim->create_halfspace(0,0,0, 0, 1, 0, 0, 13); + pairs_sim->create_halfspace(0,0,0, 0, 0, 1, 0, 13); + pairs_sim->create_halfspace(1,1,1, -1, 0, 0, 0, 13); + pairs_sim->create_halfspace(1,1,1, 0, -1, 0, 0, 13); + pairs_sim->create_halfspace(1,1,1, 0, 0, -1, 0, 13); + pairs_sim->create_sphere(0.6, 0.6, 0.7, -2, -2, 0, 1000, 0.05, 0, 0); + pairs_sim->create_sphere(0.4, 0.4, 0.68, 2, 2, 0, 1000, 0.05, 0, 0); + + pairs_sim->setup_sim(0.1, 0.1, 0.1, 0.1); + pairs_sim->update_mass_and_inertia(); + + int num_timesteps = 2000; + int vtk_freq = 20; + double dt = 1e-3; + + for (int t=0; t<num_timesteps; ++t){ + if ((t%500==0) && pairs_sim->rank()==0) std::cout << "Timestep: " << t << std::endl; + + pairs_sim->communicate(t); + + pairs_sim->update_cells(t); + + pairs_sim->gravity(); + pairs_sim->spring_dashpot(); + pairs_sim->euler(dt); + + pairs_sim->reset_volatiles(); + + pairs_sim->vtk_write("output/dem_sd_local", 0, pairs_sim->nlocal(), t, vtk_freq); + pairs_sim->vtk_write("output/dem_sd_ghost", pairs_sim->nlocal(), pairs_sim->size(), t, vtk_freq); + } + + pairs_sim->end(); +} diff --git a/examples/modular/sd_2_CPU_GPU.cpp b/examples/modular/sd_2_CPU_GPU.cpp new file mode 100644 index 0000000000000000000000000000000000000000..73b86385cccd5063fd068b8f2df1cbb2c69bd04e --- /dev/null +++ b/examples/modular/sd_2_CPU_GPU.cpp @@ -0,0 +1,69 @@ +#include <iostream> +#include <memory> + +#include <blockforest/BlockForest.h> +#include <blockforest/Initialization.h> + +#include "spring_dashpot.hpp" + +int main(int argc, char **argv) { + + auto pairs_sim = std::make_shared<PairsSimulation>(); + pairs_sim->initialize(); + + // Create forest + // ------------------------------------------------------------------------------- + walberla::math::AABB domain(0, 0, 0, 1, 1, 1); + std::shared_ptr<walberla::mpi::MPIManager> mpiManager = walberla::mpi::MPIManager::instance(); + mpiManager->initializeMPI(&argc, &argv); + mpiManager->useWorldComm(); + auto procs = mpiManager->numProcesses(); + + walberla::Vector3<int> block_config; + if (procs==1) block_config = walberla::Vector3<int>(1, 1, 1); + else if (procs==4) block_config = walberla::Vector3<int>(2, 2, 1); + else { std::cout << "Error: Check block_config" << std::endl; exit(-1);} + + auto ref_level = 0; + std::shared_ptr<walberla::BlockForest> forest = walberla::blockforest::createBlockForest( + domain, block_config, walberla::Vector3<bool>(false, false, false), procs, ref_level); + + // Pass forest to P4IRS + // ------------------------------------------------------------------------------- + pairs_sim->use_domain(forest); + + pairs_sim->create_halfspace(0,0,0, 1, 0, 0, 0, 13); + pairs_sim->create_halfspace(0,0,0, 0, 1, 0, 0, 13); + pairs_sim->create_halfspace(0,0,0, 0, 0, 1, 0, 13); + pairs_sim->create_halfspace(1,1,1, -1, 0, 0, 0, 13); + pairs_sim->create_halfspace(1,1,1, 0, -1, 0, 0, 13); + pairs_sim->create_halfspace(1,1,1, 0, 0, -1, 0, 13); + pairs_sim->create_sphere(0.6, 0.6, 0.7, -2, -2, 0, 1000, 0.05, 0, 0); + pairs_sim->create_sphere(0.4, 0.4, 0.68, 2, 2, 0, 1000, 0.05, 0, 0); + + pairs_sim->setup_sim(0.1, 0.1, 0.1, 0.1); + pairs_sim->update_mass_and_inertia(); + + int num_timesteps = 2000; + int vtk_freq = 20; + double dt = 1e-3; + + for (int t=0; t<num_timesteps; ++t){ + if ((t%500==0) && pairs_sim->rank()==0) std::cout << "Timestep: " << t << std::endl; + + pairs_sim->communicate(t); + + pairs_sim->update_cells(t); + + pairs_sim->gravity(); + pairs_sim->spring_dashpot(); + pairs_sim->euler(dt); + + pairs_sim->reset_volatiles(); + + pairs_sim->vtk_write("output/dem_sd_local", 0, pairs_sim->nlocal(), t, vtk_freq); + pairs_sim->vtk_write("output/dem_sd_ghost", pairs_sim->nlocal(), pairs_sim->size(), t, vtk_freq); + } + + pairs_sim->end(); +} diff --git a/examples/modular/sd_3_CPU.cpp b/examples/modular/sd_3_CPU.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7dc10829d64ea658b76ffc5e95f251a8d43550c5 --- /dev/null +++ b/examples/modular/sd_3_CPU.cpp @@ -0,0 +1,95 @@ +#include <iostream> +#include <memory> + +#include "spring_dashpot.hpp" + +void change_gravitational_force(std::shared_ptr<PairsAccessor> &ac, int idx){ + pairs::Vector3<double> upward_gravity(0.0, 0.0, 2 * ac->getMass(idx) * 9.81); + ac->setForce(idx, ac->getForce(idx) + upward_gravity); +} + +int main(int argc, char **argv) { + + auto pairs_sim = std::make_shared<PairsSimulation>(); + pairs_sim->initialize(); + + auto ac = std::make_shared<PairsAccessor>(pairs_sim.get()); + + pairs_sim->set_domain(argc, argv, 0, 0, 0, 1, 1, 1); + + pairs_sim->create_halfspace(0,0,0, 1, 0, 0, 0, 13); + pairs_sim->create_halfspace(0,0,0, 0, 1, 0, 0, 13); + pairs_sim->create_halfspace(0,0,0, 0, 0, 1, 0, 13); + pairs_sim->create_halfspace(1,1,1, -1, 0, 0, 0, 13); + pairs_sim->create_halfspace(1,1,1, 0, -1, 0, 0, 13); + pairs_sim->create_halfspace(1,1,1, 0, 0, -1, 0, 13); + + pairs::id_t pUid = pairs_sim->create_sphere(0.6, 0.6, 0.7, 0, 0, 0, 1000, 0.05, 0, 0); + pairs_sim->create_sphere(0.4, 0.4, 0.76, 2, 2, 0, 1000, 0.05, 0, 0); + + MPI_Allreduce(MPI_IN_PLACE, &pUid, 1, MPI_LONG_LONG_INT, MPI_SUM, MPI_COMM_WORLD); + + auto pIsLocalInMyRank = [&](pairs::id_t uid){return ac->uidToIdxLocal(uid) != ac->getInvalidIdx();}; + + pairs_sim->setup_sim(0.1, 0.1, 0.1, 0.1); + pairs_sim->update_mass_and_inertia(); + + pairs_sim->communicate(0); + + int num_timesteps = 2000; + int vtk_freq = 20; + double dt = 1e-3; + + for (int t=0; t<num_timesteps; ++t){ + + // Print position of particle pUid + //------------------------------------------------------------------------------------------- + if(pIsLocalInMyRank(pUid)){ + std::cout << "Timestep (" << t << "): Particle " << pUid << " is in rank " << pairs_sim->rank() << std::endl; + int idx = ac->uidToIdxLocal(pUid); + std::cout << "Position = (" + << ac->getPosition(idx)[0] << ", " + << ac->getPosition(idx)[1] << ", " + << ac->getPosition(idx)[2] << ")" << std::endl; + + } + + // Calculate forces + //------------------------------------------------------------------------------------------- + pairs_sim->update_cells(t); + pairs_sim->gravity(); + pairs_sim->spring_dashpot(); + + // Change gravitational force on particle pUid + //------------------------------------------------------------------------------------------- + if(pIsLocalInMyRank(pUid)){ + int idx = ac->uidToIdxLocal(pUid); + + std::cout << "Force before changing = (" + << ac->getForce(idx)[0] << ", " + << ac->getForce(idx)[1] << ", " + << ac->getForce(idx)[2] << ")" << std::endl; + + change_gravitational_force(ac, idx); + + std::cout << "Force after changing = (" + << ac->getForce(idx)[0] << ", " + << ac->getForce(idx)[1] << ", " + << ac->getForce(idx)[2] << ")" << std::endl; + } + + // Euler + //------------------------------------------------------------------------------------------- + pairs_sim->euler(dt); + pairs_sim->reset_volatiles(); + + // Communicate + //------------------------------------------------------------------------------------------- + pairs_sim->communicate(t); + + pairs_sim->vtk_write("output/dem_sd_local", 0, ac->nlocal(), t, vtk_freq); + pairs_sim->vtk_write("output/dem_sd_ghost", ac->nlocal(), ac->size(), t, vtk_freq); + } + + pairs_sim->end(); +} \ No newline at end of file diff --git a/examples/modular/sd_3_GPU.cu b/examples/modular/sd_3_GPU.cu new file mode 100644 index 0000000000000000000000000000000000000000..4abdeedcdc844a6f94975bca0aedf1586df6c5e8 --- /dev/null +++ b/examples/modular/sd_3_GPU.cu @@ -0,0 +1,153 @@ +#include <iostream> +#include <memory> +#include <cuda_runtime.h> + +#include "spring_dashpot.hpp" + +void checkCudaError(cudaError_t err, const char* func) { + if (err != cudaSuccess) { + fprintf(stderr, "CUDA error in %s: %s\n", func, cudaGetErrorString(err)); + exit(err); + } +} + +__global__ void print_position(PairsAccessor ac, int idx){ + printf("Position [from device] = (%f, %f, %f) \n", ac.getPosition(idx)[0], ac.getPosition(idx)[1], ac.getPosition(idx)[2]); +} + +__global__ void change_gravitational_force(PairsAccessor ac, int idx){ + printf("Force [from device] before setting = (%f, %f, %f) \n", ac.getForce(idx)[0], ac.getForce(idx)[1], ac.getForce(idx)[2]); + + pairs::Vector3<double> upward_gravity(0.0, 0.0, 2 * ac.getMass(idx) * 9.81); + ac.setForce(idx, ac.getForce(idx) + upward_gravity); + + printf("Force [from device] after setting = (%f, %f, %f) \n", ac.getForce(idx)[0], ac.getForce(idx)[1], ac.getForce(idx)[2]); +} + +void set_feature_properties(std::shared_ptr<PairsAccessor> &ac){ + ac->setTypeStiffness(0,0, 0); + ac->setTypeStiffness(0,1, 1000); + ac->setTypeStiffness(1,0, 1000); + ac->setTypeStiffness(1,1, 3000); + ac->syncTypeStiffness(); + + ac->setTypeDampingNorm(0,0, 0); + ac->setTypeDampingNorm(0,1, 20); + ac->setTypeDampingNorm(1,0, 20); + ac->setTypeDampingNorm(1,1, 10); + ac->syncTypeDampingNorm(); +} + +int main(int argc, char **argv) { + + auto pairs_sim = std::make_shared<PairsSimulation>(); + pairs_sim->initialize(); + + // Create PairsAccessor after PairsSimulation is initialized + auto ac = std::make_shared<PairsAccessor>(pairs_sim.get()); + + pairs_sim->set_domain(argc, argv, 0, 0, 0, 1, 1, 1); + + pairs_sim->create_halfspace(0,0,0, 1, 0, 0, 0, 13); + pairs_sim->create_halfspace(0,0,0, 0, 1, 0, 0, 13); + pairs_sim->create_halfspace(0,0,0, 0, 0, 1, 0, 13); + pairs_sim->create_halfspace(1,1,1, -1, 0, 0, 0, 13); + pairs_sim->create_halfspace(1,1,1, 0, -1, 0, 0, 13); + pairs_sim->create_halfspace(1,1,1, 0, 0, -1, 0, 13); + + pairs::id_t pUid = pairs_sim->create_sphere(0.6, 0.6, 0.7, 0, 0, 0, 1000, 0.05, 1, 0); + pairs_sim->create_sphere(0.4, 0.4, 0.76, 2, 2, 0, 1000, 0.05, 1, 0); + + set_feature_properties(ac); + + MPI_Allreduce(MPI_IN_PLACE, &pUid, 1, MPI_LONG_LONG_INT, MPI_SUM, MPI_COMM_WORLD); + + auto pIsLocalInMyRank = [&](pairs::id_t uid){return ac->uidToIdxLocal(uid) != ac->getInvalidIdx();}; + + pairs_sim->setup_sim(0.1, 0.1, 0.1, 0.1); + pairs_sim->update_mass_and_inertia(); + + pairs_sim->communicate(0); + // PairsAccessor requires an update when particles are communicated + ac->update(); + + int num_timesteps = 2000; + int vtk_freq = 20; + double dt = 1e-3; + pairs_sim->vtk_write_subdom("output/subdom", 0, 1); + + for (int t=0; t<num_timesteps; ++t){ + // Up-to-date uids might be on host or device. So sync uid in Host before accessing them from host + ac->syncUid(PairsAccessor::Host); + + // Print position of particle pUid + //------------------------------------------------------------------------------------------- + if(pIsLocalInMyRank(pUid)){ + std::cout << "Timestep (" << t << "): Particle " << pUid << " is in rank " << pairs_sim->rank() << std::endl; + int idx = ac->uidToIdxLocal(pUid); + + // Up-to-date position might be on host or device. + // Sync position on Host before reading it from host: + ac->syncPosition(PairsAccessor::Host); + std::cout << "Position [from host] = (" + << ac->getPosition(idx)[0] << ", " + << ac->getPosition(idx)[1] << ", " + << ac->getPosition(idx)[2] << ")" << std::endl; + + // Sync position on Device before reading it from device: + ac->syncPosition(PairsAccessor::Device); + print_position<<<1,1>>>(*ac, idx); + checkCudaError(cudaDeviceSynchronize(), "print_position"); + + // There's no need to sync position here to continue the simulation, since position wasn't modified. + } + + // Calculate forces + //------------------------------------------------------------------------------------------- + pairs_sim->update_cells(t); + pairs_sim->gravity(); + pairs_sim->spring_dashpot(); + + // Change gravitational force on particle pUid + //------------------------------------------------------------------------------------------- + ac->syncUid(PairsAccessor::Host); + + if(pIsLocalInMyRank(pUid)){ + std::cout << "Force Timestep (" << t << "): Particle " << pUid << " is in rank " << pairs_sim->rank() << std::endl; + int idx = ac->uidToIdxLocal(pUid); + + // Up-to-date force and mass might be on host or device. + // So sync them in Device before accessing them on device. (No data will be transfered if they are already on device) + ac->syncForce(PairsAccessor::Device); + ac->syncMass(PairsAccessor::Device); + + // Modify force from device: + change_gravitational_force<<<1,1>>>(*ac, idx); + checkCudaError(cudaDeviceSynchronize(), "change_gravitational_force"); + + // Force on device was modified. + // So sync force before continuing the simulation. + ac->syncForce(PairsAccessor::Host); + std::cout << "Force [from host] after changing = (" + << ac->getForce(idx)[0] << ", " + << ac->getForce(idx)[1] << ", " + << ac->getForce(idx)[2] << ")" << std::endl; + } + + // Euler + //------------------------------------------------------------------------------------------- + pairs_sim->euler(dt); + pairs_sim->reset_volatiles(); + + // Communicate + //------------------------------------------------------------------------------------------- + pairs_sim->communicate(t); + // PairsAccessor requires an update when particles are communicated + ac->update(); + + pairs_sim->vtk_write("output/dem_sd_local", 0, ac->nlocal(), t, vtk_freq); + pairs_sim->vtk_write("output/dem_sd_ghost", ac->nlocal(), ac->size(), t, vtk_freq); + } + + pairs_sim->end(); +} \ No newline at end of file diff --git a/examples/modular/spring_dashpot.py b/examples/modular/spring_dashpot.py index 1fb833da2032b2ccf08f200a037d2f9ee8dcbc96..191c000ca61962af8ebae77ab4ec1b97b433d399 100644 --- a/examples/modular/spring_dashpot.py +++ b/examples/modular/spring_dashpot.py @@ -71,11 +71,6 @@ elif target == 'cpu': else: print(f"Invalid target, use {sys.argv[0]} <cpu/gpu>") -gravity_SI = 9.81 -diameter = 100 # required for linkedCellWidth. TODO: set linkedCellWidth at runtime -linkedCellWidth = 1.01 * diameter -ntypes = 2 - psim.add_position('position') psim.add_property('mass', pairs.real()) psim.add_property('linear_velocity', pairs.vector()) @@ -88,6 +83,7 @@ psim.add_property('inv_inertia', pairs.matrix()) psim.add_property('rotation_matrix', pairs.matrix()) psim.add_property('rotation', pairs.quaternion()) +ntypes = 2 psim.add_feature('type', ntypes) psim.add_feature_property('type', 'stiffness', pairs.real(), [3000 for i in range(ntypes * ntypes)]) psim.add_feature_property('type', 'damping_norm', pairs.real(), [10.0 for i in range(ntypes * ntypes)]) @@ -96,13 +92,15 @@ psim.add_feature_property('type', 'friction', pairs.real()) psim.set_domain_partitioner(pairs.block_forest()) psim.pbc([False, False, False]) -psim.build_cell_lists(linkedCellWidth) +psim.build_cell_lists() # The order of user-defined functions is not important here since # they are not used by other subroutines and are only callable individually psim.compute(update_mass_and_inertia, symbols={'infinity': math.inf }) -psim.compute(spring_dashpot, linkedCellWidth) +psim.compute(spring_dashpot) psim.compute(euler, parameters={'dt': pairs.real()}) + +gravity_SI = 9.81 psim.compute(gravity, symbols={'gravity_SI': gravity_SI }) psim.generate() diff --git a/src/pairs/code_gen/interface.py b/src/pairs/code_gen/interface.py index 0e52deed7d6386fe4cd054188e4868773a014e6c..6255ff4424679726c424952e837bce38a193ac97 100644 --- a/src/pairs/code_gen/interface.py +++ b/src/pairs/code_gen/interface.py @@ -93,6 +93,14 @@ class InterfaceModules: @pairs_interface_block def setup_sim(self): self.sim.module_name('setup_sim') + + if self.sim.cell_lists.runtime_spacing: + for d in range(self.sim.dims): + Assign(self.sim, self.sim.cell_lists.spacing[d], Parameter(self.sim, f'cell_spacing_d{d}', Types.Real)) + + if self.sim.cell_lists.runtime_cutoff_radius: + Assign(self.sim, self.sim.cell_lists.cutoff_radius, Parameter(self.sim, 'cutoff_radius', Types.Real)) + self.sim.add_statement(self.sim.setup_particles) self.sim.add_statement(UpdateDomain(self.sim)) self.sim.add_statement(BuildCellListsStencil(self.sim, self.sim.cell_lists)) diff --git a/src/pairs/ir/math.py b/src/pairs/ir/math.py index e85aa0678f7fbfc6372665fdfab78f346eb21e97..a6a156a4986a6a8b122dc70d8ca920ad18d29269 100644 --- a/src/pairs/ir/math.py +++ b/src/pairs/ir/math.py @@ -1,6 +1,7 @@ from pairs.ir.ast_term import ASTTerm from pairs.ir.scalars import ScalarOp from pairs.ir.types import Types +from pairs.ir.lit import Lit class MathFunction(ASTTerm): @@ -115,6 +116,7 @@ class Cos(MathFunction): class Ceil(MathFunction): def __init__(self, sim, expr): + expr = Lit.cvt(sim, expr) assert Types.is_real(expr.type()), "Expression must be of real type!" super().__init__(sim) self._params = [expr] diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py index e8b2d2ead9930eba21d7dea41eedf6bff00441ce..4038f48fb628f6c00c55c5cf535b43e8ce1a9af3 100644 --- a/src/pairs/sim/cell_lists.py +++ b/src/pairs/sim/cell_lists.py @@ -17,15 +17,31 @@ from pairs.sim.lowerable import Lowerable class CellLists: - def __init__(self, sim, dom_part, spacing, cutoff_radius): + def __init__(self, sim, dom_part, spacing=None, cutoff_radius=None): self.sim = sim self.dom_part = dom_part - self.spacing = spacing if isinstance(spacing, list) else [spacing for d in range(sim.ndims())] - self.cutoff_radius = cutoff_radius - self.nneighbor_cells = [math.ceil(cutoff_radius / self.spacing[d]) for d in range(sim.ndims())] - self.nstencil_max = reduce((lambda x, y: x * y), [self.nneighbor_cells[d] * 2 + 1 for d in range(sim.ndims())]) + + # Cell spacing and cutoff radius can be set at runtime + # only if they haven't been pre-set in the input script + if spacing: + self.spacing = spacing if isinstance(spacing, list) else [spacing for d in range(sim.ndims())] + self.runtime_spacing = False + else: + assert self.sim._generate_whole_program==False, "Cell spacing needs to be defined when generating whole program." + self.spacing = self.sim.add_array('spacing', self.sim.ndims(), Types.Real) + self.runtime_spacing = True + + if cutoff_radius: + self.cutoff_radius = cutoff_radius + self.runtime_cutoff_radius = False + else: + assert self.sim._generate_whole_program==False, "cutoff_radius needs to be defined when generating whole program." + self.cutoff_radius = self.sim.add_var('cutoff_radius', Types.Real) + self.runtime_cutoff_radius = True + # Data introduced in the simulation self.nstencil = self.sim.add_var('nstencil', Types.Int32) + self.nstencil_capacity = self.sim.add_var('nstencil_capacity', Types.Int32, 27) self.ncells = self.sim.add_var('ncells', Types.Int32, 1) self.ncells_capacity = self.sim.add_var('ncells_capacity', Types.Int32, 100000) self.cell_capacity = self.sim.add_var('cell_capacity', Types.Int32, 64) @@ -34,7 +50,7 @@ class CellLists: self.cell_particles = self.sim.add_array('cell_particles', [self.ncells_capacity, self.cell_capacity], Types.Int32) self.cell_sizes = self.sim.add_array('cell_sizes', self.ncells_capacity, Types.Int32) self.nshapes = self.sim.add_array('nshapes', [self.ncells_capacity, self.sim.max_shapes()], Types.Int32) - self.stencil = self.sim.add_array('stencil', self.nstencil_max, Types.Int32) + self.stencil = self.sim.add_array('stencil', self.nstencil_capacity, Types.Int32) self.particle_cell = self.sim.add_array('particle_cell', self.sim.particle_capacity, Types.Int32) if sim._store_neighbors_per_cell: @@ -52,8 +68,9 @@ class BuildCellListsStencil(Lowerable): def lower(self): stencil = self.cell_lists.stencil nstencil = self.cell_lists.nstencil + nstencil_capacity = self.cell_lists.nstencil_capacity spacing = self.cell_lists.spacing - nneighbor_cells = self.cell_lists.nneighbor_cells + cutoff_radius = self.cell_lists.cutoff_radius dim_ncells = self.cell_lists.dim_ncells ncells = self.cell_lists.ncells ncells_capacity = self.cell_lists.ncells_capacity @@ -63,6 +80,7 @@ class BuildCellListsStencil(Lowerable): self.sim.module_name("build_cell_lists_stencil") self.sim.check_resize(ncells_capacity, ncells) + self.sim.check_resize(nstencil_capacity, nstencil) for s in range(self.sim.max_shapes()): Assign(self.sim, shapes_buffer[s], self.sim.get_shape_id(s)) @@ -79,7 +97,7 @@ class BuildCellListsStencil(Lowerable): Assign(self.sim, nstencil, 0) for dim in range(self.sim.ndims()): - nneigh = nneighbor_cells[dim] + nneigh = Ceil(self.sim,(cutoff_radius / spacing[dim])) for dim_offset in For(self.sim, -nneigh, nneigh + 1): index = dim_offset if index is None else index * dim_ncells[dim] + dim_offset if dim == self.sim.ndims() - 1: diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index f58082ea8d592351f4833d82594b0423b2906401..2c0e194981872d1195a467421519cc624dc6a8d5 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -329,14 +329,16 @@ class Simulation: DEMSCGrid(self, xmax, ymax, zmax, spacing, diameter, min_diameter, max_diameter, initial_velocity, particle_density, ntypes)) - def build_cell_lists(self, spacing, store_neighbors_per_cell=False): - """Add routines to build the linked-cells acceleration structure""" + def build_cell_lists(self, spacing=None, store_neighbors_per_cell=False): + """Add routines to build the linked-cells acceleration structure. + Leave spacing as None so it can be set at runtime.""" self._store_neighbors_per_cell = store_neighbors_per_cell self.cell_lists = CellLists(self, self._dom_part, spacing, spacing) return self.cell_lists - def build_neighbor_lists(self, spacing): - """Add routines to build the Verlet Lists acceleration structure""" + def build_neighbor_lists(self, spacing=None): + """Add routines to build the Verlet Lists acceleration structure. + Leave spacing as None so it can be set at runtime.""" assert self._store_neighbors_per_cell is False, \ "Using neighbor-lists with store_neighbors_per_cell option is invalid."