From a631855f6cae506b0a1fe9b48089cd21a90fa5c7 Mon Sep 17 00:00:00 2001
From: Behzad Safaei <iwia103h@a0522.nhr.fau.de>
Date: Tue, 18 Feb 2025 23:01:27 +0100
Subject: [PATCH] Add option for setting cell-spacing and cutoff-radius at
 runtime

---
 examples/modular/sd_1_CPU_GPU.cpp  |  47 +++++++++
 examples/modular/sd_2_CPU_GPU.cpp  |  69 +++++++++++++
 examples/modular/sd_3_CPU.cpp      |  95 ++++++++++++++++++
 examples/modular/sd_3_GPU.cu       | 153 +++++++++++++++++++++++++++++
 examples/modular/spring_dashpot.py |  12 +--
 src/pairs/code_gen/interface.py    |   8 ++
 src/pairs/ir/math.py               |   2 +
 src/pairs/sim/cell_lists.py        |  34 +++++--
 src/pairs/sim/simulation.py        |  10 +-
 9 files changed, 411 insertions(+), 19 deletions(-)
 create mode 100644 examples/modular/sd_1_CPU_GPU.cpp
 create mode 100644 examples/modular/sd_2_CPU_GPU.cpp
 create mode 100644 examples/modular/sd_3_CPU.cpp
 create mode 100644 examples/modular/sd_3_GPU.cu

diff --git a/examples/modular/sd_1_CPU_GPU.cpp b/examples/modular/sd_1_CPU_GPU.cpp
new file mode 100644
index 0000000..898d3e9
--- /dev/null
+++ b/examples/modular/sd_1_CPU_GPU.cpp
@@ -0,0 +1,47 @@
+#include <iostream>
+#include <memory>
+
+#include "spring_dashpot.hpp"
+
+int main(int argc, char **argv) {
+
+    auto pairs_sim = std::make_shared<PairsSimulation>();
+    pairs_sim->initialize();
+
+    pairs_sim->set_domain(argc, argv, 0, 0, 0, 1, 1, 1);
+
+    pairs_sim->create_halfspace(0,0,0,  1, 0, 0,     0, 13);
+    pairs_sim->create_halfspace(0,0,0,  0, 1, 0,     0, 13);
+    pairs_sim->create_halfspace(0,0,0,  0, 0, 1,     0, 13);
+    pairs_sim->create_halfspace(1,1,1,  -1, 0, 0,    0, 13);
+    pairs_sim->create_halfspace(1,1,1,  0, -1, 0,    0, 13);
+    pairs_sim->create_halfspace(1,1,1,  0, 0, -1,    0, 13);
+    pairs_sim->create_sphere(0.6, 0.6, 0.7,      -2, -2, 0,  1000, 0.05, 0, 0);
+    pairs_sim->create_sphere(0.4, 0.4, 0.68,    2, 2, 0,    1000, 0.05, 0, 0);
+
+    pairs_sim->setup_sim(0.1, 0.1, 0.1, 0.1);
+    pairs_sim->update_mass_and_inertia();
+
+    int num_timesteps = 2000;
+    int vtk_freq = 20;
+    double dt = 1e-3;
+    
+    for (int t=0; t<num_timesteps; ++t){
+        if ((t%500==0) && pairs_sim->rank()==0) std::cout << "Timestep: " << t << std::endl;
+
+        pairs_sim->communicate(t);
+        
+        pairs_sim->update_cells(t); 
+
+        pairs_sim->gravity(); 
+        pairs_sim->spring_dashpot(); 
+        pairs_sim->euler(dt); 
+
+        pairs_sim->reset_volatiles(); 
+
+        pairs_sim->vtk_write("output/dem_sd_local", 0, pairs_sim->nlocal(), t, vtk_freq);
+        pairs_sim->vtk_write("output/dem_sd_ghost", pairs_sim->nlocal(), pairs_sim->size(), t, vtk_freq);
+    }
+
+    pairs_sim->end();
+}
diff --git a/examples/modular/sd_2_CPU_GPU.cpp b/examples/modular/sd_2_CPU_GPU.cpp
new file mode 100644
index 0000000..73b8638
--- /dev/null
+++ b/examples/modular/sd_2_CPU_GPU.cpp
@@ -0,0 +1,69 @@
+#include <iostream>
+#include <memory>
+
+#include <blockforest/BlockForest.h>
+#include <blockforest/Initialization.h>
+
+#include "spring_dashpot.hpp"
+
+int main(int argc, char **argv) {
+
+    auto pairs_sim = std::make_shared<PairsSimulation>();
+    pairs_sim->initialize();
+
+    // Create forest
+    // -------------------------------------------------------------------------------
+    walberla::math::AABB domain(0, 0, 0, 1, 1, 1);
+    std::shared_ptr<walberla::mpi::MPIManager> mpiManager = walberla::mpi::MPIManager::instance();
+    mpiManager->initializeMPI(&argc, &argv);
+    mpiManager->useWorldComm();
+    auto procs = mpiManager->numProcesses();
+
+    walberla::Vector3<int> block_config;
+    if (procs==1)        block_config = walberla::Vector3<int>(1, 1, 1);
+    else if (procs==4)   block_config = walberla::Vector3<int>(2, 2, 1);
+    else { std::cout << "Error: Check block_config" << std::endl; exit(-1);} 
+
+    auto ref_level = 0;
+    std::shared_ptr<walberla::BlockForest> forest = walberla::blockforest::createBlockForest(
+            domain, block_config, walberla::Vector3<bool>(false, false, false), procs, ref_level);
+
+    // Pass forest to P4IRS
+    // -------------------------------------------------------------------------------
+    pairs_sim->use_domain(forest);
+
+    pairs_sim->create_halfspace(0,0,0,  1, 0, 0,     0, 13);
+    pairs_sim->create_halfspace(0,0,0,  0, 1, 0,     0, 13);
+    pairs_sim->create_halfspace(0,0,0,  0, 0, 1,     0, 13);
+    pairs_sim->create_halfspace(1,1,1,  -1, 0, 0,    0, 13);
+    pairs_sim->create_halfspace(1,1,1,  0, -1, 0,    0, 13);
+    pairs_sim->create_halfspace(1,1,1,  0, 0, -1,    0, 13);
+    pairs_sim->create_sphere(0.6, 0.6, 0.7,      -2, -2, 0,  1000, 0.05, 0, 0);
+    pairs_sim->create_sphere(0.4, 0.4, 0.68,    2, 2, 0,    1000, 0.05, 0, 0);
+
+    pairs_sim->setup_sim(0.1, 0.1, 0.1, 0.1);
+    pairs_sim->update_mass_and_inertia();
+
+    int num_timesteps = 2000;
+    int vtk_freq = 20;
+    double dt = 1e-3;
+
+    for (int t=0; t<num_timesteps; ++t){
+        if ((t%500==0) && pairs_sim->rank()==0) std::cout << "Timestep: " << t << std::endl;
+
+        pairs_sim->communicate(t);
+        
+        pairs_sim->update_cells(t); 
+
+        pairs_sim->gravity(); 
+        pairs_sim->spring_dashpot(); 
+        pairs_sim->euler(dt); 
+
+        pairs_sim->reset_volatiles(); 
+
+        pairs_sim->vtk_write("output/dem_sd_local", 0, pairs_sim->nlocal(), t, vtk_freq);
+        pairs_sim->vtk_write("output/dem_sd_ghost", pairs_sim->nlocal(), pairs_sim->size(), t, vtk_freq);
+    }
+
+    pairs_sim->end();
+}
diff --git a/examples/modular/sd_3_CPU.cpp b/examples/modular/sd_3_CPU.cpp
new file mode 100644
index 0000000..7dc1082
--- /dev/null
+++ b/examples/modular/sd_3_CPU.cpp
@@ -0,0 +1,95 @@
+#include <iostream>
+#include <memory>
+
+#include "spring_dashpot.hpp"
+
+void change_gravitational_force(std::shared_ptr<PairsAccessor> &ac, int idx){
+    pairs::Vector3<double> upward_gravity(0.0, 0.0, 2 * ac->getMass(idx) * 9.81); 
+    ac->setForce(idx, ac->getForce(idx) + upward_gravity);
+}
+
+int main(int argc, char **argv) {
+
+    auto pairs_sim = std::make_shared<PairsSimulation>();
+    pairs_sim->initialize();
+
+    auto ac = std::make_shared<PairsAccessor>(pairs_sim.get());
+
+    pairs_sim->set_domain(argc, argv, 0, 0, 0, 1, 1, 1);
+
+    pairs_sim->create_halfspace(0,0,0,  1, 0, 0,     0, 13);
+    pairs_sim->create_halfspace(0,0,0,  0, 1, 0,     0, 13);
+    pairs_sim->create_halfspace(0,0,0,  0, 0, 1,     0, 13);
+    pairs_sim->create_halfspace(1,1,1,  -1, 0, 0,    0, 13);
+    pairs_sim->create_halfspace(1,1,1,  0, -1, 0,    0, 13);
+    pairs_sim->create_halfspace(1,1,1,  0, 0, -1,    0, 13);
+
+    pairs::id_t pUid = pairs_sim->create_sphere(0.6, 0.6, 0.7,      0, 0, 0,  1000, 0.05, 0, 0);
+    pairs_sim->create_sphere(0.4, 0.4, 0.76,    2, 2, 0,    1000, 0.05, 0, 0);
+
+    MPI_Allreduce(MPI_IN_PLACE, &pUid, 1, MPI_LONG_LONG_INT, MPI_SUM, MPI_COMM_WORLD);
+
+    auto pIsLocalInMyRank = [&](pairs::id_t uid){return ac->uidToIdxLocal(uid) != ac->getInvalidIdx();};
+
+    pairs_sim->setup_sim(0.1, 0.1, 0.1, 0.1);
+    pairs_sim->update_mass_and_inertia();
+
+    pairs_sim->communicate(0);
+
+    int num_timesteps = 2000;
+    int vtk_freq = 20;
+    double dt = 1e-3;
+
+    for (int t=0; t<num_timesteps; ++t){
+
+        // Print position of particle pUid
+        //-------------------------------------------------------------------------------------------
+        if(pIsLocalInMyRank(pUid)){
+            std::cout << "Timestep (" << t << "): Particle " << pUid << " is in rank " << pairs_sim->rank() << std::endl;
+            int idx = ac->uidToIdxLocal(pUid);
+            std::cout << "Position = (" 
+                    << ac->getPosition(idx)[0] << ", "
+                    << ac->getPosition(idx)[1] << ", " 
+                    << ac->getPosition(idx)[2] << ")" << std::endl;
+
+        }
+
+        // Calculate forces
+        //-------------------------------------------------------------------------------------------
+        pairs_sim->update_cells(t);
+        pairs_sim->gravity(); 
+        pairs_sim->spring_dashpot(); 
+
+        // Change gravitational force on particle pUid
+        //-------------------------------------------------------------------------------------------
+        if(pIsLocalInMyRank(pUid)){
+            int idx = ac->uidToIdxLocal(pUid);
+
+            std::cout << "Force before changing = (" 
+                    << ac->getForce(idx)[0] << ", "
+                    << ac->getForce(idx)[1] << ", " 
+                    << ac->getForce(idx)[2] << ")" << std::endl;
+
+            change_gravitational_force(ac, idx);
+
+            std::cout << "Force after changing = (" 
+                    << ac->getForce(idx)[0] << ", "
+                    << ac->getForce(idx)[1] << ", " 
+                    << ac->getForce(idx)[2] << ")" << std::endl;
+        }
+
+        // Euler
+        //-------------------------------------------------------------------------------------------
+        pairs_sim->euler(dt);
+        pairs_sim->reset_volatiles(); 
+
+        // Communicate
+        //-------------------------------------------------------------------------------------------
+        pairs_sim->communicate(t);
+
+        pairs_sim->vtk_write("output/dem_sd_local", 0, ac->nlocal(), t, vtk_freq);
+        pairs_sim->vtk_write("output/dem_sd_ghost", ac->nlocal(), ac->size(), t, vtk_freq);
+    }
+
+    pairs_sim->end();
+}
\ No newline at end of file
diff --git a/examples/modular/sd_3_GPU.cu b/examples/modular/sd_3_GPU.cu
new file mode 100644
index 0000000..4abdeed
--- /dev/null
+++ b/examples/modular/sd_3_GPU.cu
@@ -0,0 +1,153 @@
+#include <iostream>
+#include <memory>
+#include <cuda_runtime.h>
+
+#include "spring_dashpot.hpp"
+
+void checkCudaError(cudaError_t err, const char* func) {
+    if (err != cudaSuccess) {
+        fprintf(stderr, "CUDA error in %s: %s\n", func, cudaGetErrorString(err));
+        exit(err);
+    }
+}
+
+__global__ void print_position(PairsAccessor ac, int idx){
+    printf("Position [from device] = (%f, %f, %f) \n", ac.getPosition(idx)[0], ac.getPosition(idx)[1], ac.getPosition(idx)[2]);
+}
+
+__global__ void change_gravitational_force(PairsAccessor ac, int idx){
+    printf("Force [from device] before setting = (%f, %f, %f) \n", ac.getForce(idx)[0], ac.getForce(idx)[1], ac.getForce(idx)[2]);
+
+    pairs::Vector3<double> upward_gravity(0.0, 0.0, 2 * ac.getMass(idx) * 9.81); 
+    ac.setForce(idx, ac.getForce(idx) + upward_gravity);
+
+    printf("Force [from device] after setting = (%f, %f, %f) \n", ac.getForce(idx)[0], ac.getForce(idx)[1], ac.getForce(idx)[2]);
+}
+
+void set_feature_properties(std::shared_ptr<PairsAccessor> &ac){
+    ac->setTypeStiffness(0,0, 0);
+    ac->setTypeStiffness(0,1, 1000);
+    ac->setTypeStiffness(1,0, 1000);
+    ac->setTypeStiffness(1,1, 3000);
+    ac->syncTypeStiffness();
+
+    ac->setTypeDampingNorm(0,0, 0);
+    ac->setTypeDampingNorm(0,1, 20);
+    ac->setTypeDampingNorm(1,0, 20);
+    ac->setTypeDampingNorm(1,1, 10);
+    ac->syncTypeDampingNorm();
+}
+
+int main(int argc, char **argv) {
+
+    auto pairs_sim = std::make_shared<PairsSimulation>();
+    pairs_sim->initialize();
+
+    // Create PairsAccessor after PairsSimulation is initialized
+    auto ac = std::make_shared<PairsAccessor>(pairs_sim.get());
+
+    pairs_sim->set_domain(argc, argv, 0, 0, 0, 1, 1, 1);
+
+    pairs_sim->create_halfspace(0,0,0,  1, 0, 0,     0, 13);
+    pairs_sim->create_halfspace(0,0,0,  0, 1, 0,     0, 13);
+    pairs_sim->create_halfspace(0,0,0,  0, 0, 1,     0, 13);
+    pairs_sim->create_halfspace(1,1,1,  -1, 0, 0,    0, 13);
+    pairs_sim->create_halfspace(1,1,1,  0, -1, 0,    0, 13);
+    pairs_sim->create_halfspace(1,1,1,  0, 0, -1,    0, 13);
+
+    pairs::id_t pUid = pairs_sim->create_sphere(0.6, 0.6, 0.7,      0, 0, 0,  1000, 0.05, 1, 0);
+    pairs_sim->create_sphere(0.4, 0.4, 0.76,    2, 2, 0,    1000, 0.05, 1, 0);
+
+    set_feature_properties(ac);
+
+    MPI_Allreduce(MPI_IN_PLACE, &pUid, 1, MPI_LONG_LONG_INT, MPI_SUM, MPI_COMM_WORLD);
+
+    auto pIsLocalInMyRank = [&](pairs::id_t uid){return ac->uidToIdxLocal(uid) != ac->getInvalidIdx();};
+
+    pairs_sim->setup_sim(0.1, 0.1, 0.1, 0.1);
+    pairs_sim->update_mass_and_inertia();
+
+    pairs_sim->communicate(0);
+    // PairsAccessor requires an update when particles are communicated 
+    ac->update();
+
+    int num_timesteps = 2000;
+    int vtk_freq = 20;
+    double dt = 1e-3;
+    pairs_sim->vtk_write_subdom("output/subdom", 0, 1);
+
+    for (int t=0; t<num_timesteps; ++t){
+        // Up-to-date uids might be on host or device. So sync uid in Host before accessing them from host
+        ac->syncUid(PairsAccessor::Host);
+
+        // Print position of particle pUid
+        //-------------------------------------------------------------------------------------------
+        if(pIsLocalInMyRank(pUid)){
+            std::cout << "Timestep (" << t << "): Particle " << pUid << " is in rank " << pairs_sim->rank() << std::endl;
+            int idx = ac->uidToIdxLocal(pUid);
+
+            // Up-to-date position might be on host or device. 
+            // Sync position on Host before reading it from host:
+            ac->syncPosition(PairsAccessor::Host); 
+            std::cout << "Position [from host] = (" 
+                    << ac->getPosition(idx)[0] << ", "
+                    << ac->getPosition(idx)[1] << ", " 
+                    << ac->getPosition(idx)[2] << ")" << std::endl;
+            
+            // Sync position on Device before reading it from device:
+            ac->syncPosition(PairsAccessor::Device); 
+            print_position<<<1,1>>>(*ac, idx);
+            checkCudaError(cudaDeviceSynchronize(), "print_position");
+            
+            // There's no need to sync position here to continue the simulation, since position wasn't modified.
+        }
+
+        // Calculate forces
+        //-------------------------------------------------------------------------------------------
+        pairs_sim->update_cells(t);
+        pairs_sim->gravity(); 
+        pairs_sim->spring_dashpot(); 
+
+        // Change gravitational force on particle pUid
+        //-------------------------------------------------------------------------------------------
+        ac->syncUid(PairsAccessor::Host);
+
+        if(pIsLocalInMyRank(pUid)){
+            std::cout << "Force Timestep (" << t << "): Particle " << pUid << " is in rank " << pairs_sim->rank() << std::endl;
+            int idx = ac->uidToIdxLocal(pUid);
+
+            // Up-to-date force and mass might be on host or device. 
+            // So sync them in Device before accessing them on device. (No data will be transfered if they are already on device)
+            ac->syncForce(PairsAccessor::Device);
+            ac->syncMass(PairsAccessor::Device);
+
+            // Modify force from device:
+            change_gravitational_force<<<1,1>>>(*ac, idx);
+            checkCudaError(cudaDeviceSynchronize(), "change_gravitational_force");
+
+            // Force on device was modified.
+            // So sync force before continuing the simulation.
+            ac->syncForce(PairsAccessor::Host);
+            std::cout << "Force [from host] after changing = (" 
+                    << ac->getForce(idx)[0] << ", "
+                    << ac->getForce(idx)[1] << ", " 
+                    << ac->getForce(idx)[2] << ")" << std::endl;
+        }
+
+        // Euler
+        //-------------------------------------------------------------------------------------------
+        pairs_sim->euler(dt);
+        pairs_sim->reset_volatiles(); 
+
+        // Communicate
+        //-------------------------------------------------------------------------------------------
+        pairs_sim->communicate(t);
+        // PairsAccessor requires an update when particles are communicated
+        ac->update();
+
+        pairs_sim->vtk_write("output/dem_sd_local", 0, ac->nlocal(), t, vtk_freq);
+        pairs_sim->vtk_write("output/dem_sd_ghost", ac->nlocal(), ac->size(), t, vtk_freq);
+    }
+
+    pairs_sim->end();
+}
\ No newline at end of file
diff --git a/examples/modular/spring_dashpot.py b/examples/modular/spring_dashpot.py
index 1fb833d..191c000 100644
--- a/examples/modular/spring_dashpot.py
+++ b/examples/modular/spring_dashpot.py
@@ -71,11 +71,6 @@ elif target == 'cpu':
 else:
     print(f"Invalid target, use {sys.argv[0]} <cpu/gpu>")
 
-gravity_SI = 9.81
-diameter = 100      # required for linkedCellWidth. TODO: set linkedCellWidth at runtime
-linkedCellWidth = 1.01 * diameter
-ntypes = 2
-
 psim.add_position('position')
 psim.add_property('mass', pairs.real())
 psim.add_property('linear_velocity', pairs.vector())
@@ -88,6 +83,7 @@ psim.add_property('inv_inertia', pairs.matrix())
 psim.add_property('rotation_matrix', pairs.matrix())
 psim.add_property('rotation', pairs.quaternion())
 
+ntypes = 2
 psim.add_feature('type', ntypes)
 psim.add_feature_property('type', 'stiffness', pairs.real(), [3000 for i in range(ntypes * ntypes)])
 psim.add_feature_property('type', 'damping_norm', pairs.real(), [10.0 for i in range(ntypes * ntypes)])
@@ -96,13 +92,15 @@ psim.add_feature_property('type', 'friction', pairs.real())
 
 psim.set_domain_partitioner(pairs.block_forest())
 psim.pbc([False, False, False])
-psim.build_cell_lists(linkedCellWidth)
+psim.build_cell_lists()
 
 # The order of user-defined functions is not important here since 
 # they are not used by other subroutines and are only callable individually 
 psim.compute(update_mass_and_inertia, symbols={'infinity': math.inf })
-psim.compute(spring_dashpot, linkedCellWidth)
+psim.compute(spring_dashpot)
 psim.compute(euler, parameters={'dt': pairs.real()})
+
+gravity_SI = 9.81
 psim.compute(gravity, symbols={'gravity_SI': gravity_SI })
 
 psim.generate()
diff --git a/src/pairs/code_gen/interface.py b/src/pairs/code_gen/interface.py
index 0e52dee..6255ff4 100644
--- a/src/pairs/code_gen/interface.py
+++ b/src/pairs/code_gen/interface.py
@@ -93,6 +93,14 @@ class InterfaceModules:
     @pairs_interface_block
     def setup_sim(self):
         self.sim.module_name('setup_sim')
+        
+        if self.sim.cell_lists.runtime_spacing:
+            for d in range(self.sim.dims):
+                Assign(self.sim, self.sim.cell_lists.spacing[d], Parameter(self.sim, f'cell_spacing_d{d}', Types.Real))
+
+        if self.sim.cell_lists.runtime_cutoff_radius:
+            Assign(self.sim, self.sim.cell_lists.cutoff_radius, Parameter(self.sim, 'cutoff_radius', Types.Real))
+
         self.sim.add_statement(self.sim.setup_particles)
         self.sim.add_statement(UpdateDomain(self.sim))
         self.sim.add_statement(BuildCellListsStencil(self.sim, self.sim.cell_lists))
diff --git a/src/pairs/ir/math.py b/src/pairs/ir/math.py
index e85aa06..a6a156a 100644
--- a/src/pairs/ir/math.py
+++ b/src/pairs/ir/math.py
@@ -1,6 +1,7 @@
 from pairs.ir.ast_term import ASTTerm
 from pairs.ir.scalars import ScalarOp
 from pairs.ir.types import Types
+from pairs.ir.lit import Lit
 
 
 class MathFunction(ASTTerm):
@@ -115,6 +116,7 @@ class Cos(MathFunction):
 
 class Ceil(MathFunction):
     def __init__(self, sim, expr):
+        expr = Lit.cvt(sim, expr)
         assert Types.is_real(expr.type()), "Expression must be of real type!"
         super().__init__(sim)
         self._params = [expr]
diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py
index e8b2d2e..4038f48 100644
--- a/src/pairs/sim/cell_lists.py
+++ b/src/pairs/sim/cell_lists.py
@@ -17,15 +17,31 @@ from pairs.sim.lowerable import Lowerable
 
 
 class CellLists:
-    def __init__(self, sim, dom_part, spacing, cutoff_radius):
+    def __init__(self, sim, dom_part, spacing=None, cutoff_radius=None):
         self.sim = sim
         self.dom_part = dom_part
-        self.spacing = spacing if isinstance(spacing, list) else [spacing for d in range(sim.ndims())]
-        self.cutoff_radius = cutoff_radius
-        self.nneighbor_cells = [math.ceil(cutoff_radius / self.spacing[d]) for d in range(sim.ndims())]
-        self.nstencil_max = reduce((lambda x, y: x * y), [self.nneighbor_cells[d] * 2 + 1 for d in range(sim.ndims())])
+
+        # Cell spacing and cutoff radius can be set at runtime 
+        # only if they haven't been pre-set in the input script
+        if spacing:
+            self.spacing = spacing if isinstance(spacing, list) else [spacing for d in range(sim.ndims())]
+            self.runtime_spacing = False
+        else:
+            assert self.sim._generate_whole_program==False, "Cell spacing needs to be defined when generating whole program."
+            self.spacing = self.sim.add_array('spacing', self.sim.ndims(), Types.Real)
+            self.runtime_spacing = True
+
+        if cutoff_radius:
+            self.cutoff_radius = cutoff_radius
+            self.runtime_cutoff_radius = False
+        else:
+            assert self.sim._generate_whole_program==False, "cutoff_radius needs to be defined when generating whole program."
+            self.cutoff_radius = self.sim.add_var('cutoff_radius', Types.Real)
+            self.runtime_cutoff_radius = True
+
         # Data introduced in the simulation
         self.nstencil           =   self.sim.add_var('nstencil', Types.Int32)
+        self.nstencil_capacity  =   self.sim.add_var('nstencil_capacity', Types.Int32, 27)
         self.ncells             =   self.sim.add_var('ncells', Types.Int32, 1)
         self.ncells_capacity    =   self.sim.add_var('ncells_capacity', Types.Int32, 100000)
         self.cell_capacity      =   self.sim.add_var('cell_capacity', Types.Int32, 64)
@@ -34,7 +50,7 @@ class CellLists:
         self.cell_particles     =   self.sim.add_array('cell_particles', [self.ncells_capacity, self.cell_capacity], Types.Int32)
         self.cell_sizes         =   self.sim.add_array('cell_sizes', self.ncells_capacity, Types.Int32)
         self.nshapes            =   self.sim.add_array('nshapes', [self.ncells_capacity, self.sim.max_shapes()], Types.Int32)
-        self.stencil            =   self.sim.add_array('stencil', self.nstencil_max, Types.Int32)
+        self.stencil            =   self.sim.add_array('stencil', self.nstencil_capacity, Types.Int32)
         self.particle_cell      =   self.sim.add_array('particle_cell', self.sim.particle_capacity, Types.Int32)
 
         if sim._store_neighbors_per_cell:
@@ -52,8 +68,9 @@ class BuildCellListsStencil(Lowerable):
     def lower(self):
         stencil = self.cell_lists.stencil
         nstencil = self.cell_lists.nstencil
+        nstencil_capacity = self.cell_lists.nstencil_capacity
         spacing = self.cell_lists.spacing
-        nneighbor_cells = self.cell_lists.nneighbor_cells
+        cutoff_radius = self.cell_lists.cutoff_radius
         dim_ncells = self.cell_lists.dim_ncells
         ncells = self.cell_lists.ncells
         ncells_capacity = self.cell_lists.ncells_capacity
@@ -63,6 +80,7 @@ class BuildCellListsStencil(Lowerable):
 
         self.sim.module_name("build_cell_lists_stencil")
         self.sim.check_resize(ncells_capacity, ncells)
+        self.sim.check_resize(nstencil_capacity, nstencil)
 
         for s in range(self.sim.max_shapes()):
             Assign(self.sim, shapes_buffer[s], self.sim.get_shape_id(s))
@@ -79,7 +97,7 @@ class BuildCellListsStencil(Lowerable):
             Assign(self.sim, nstencil, 0)
 
             for dim in range(self.sim.ndims()):
-                nneigh = nneighbor_cells[dim]
+                nneigh = Ceil(self.sim,(cutoff_radius / spacing[dim]))
                 for dim_offset in For(self.sim, -nneigh, nneigh + 1):
                     index = dim_offset if index is None else index * dim_ncells[dim] + dim_offset
                     if dim == self.sim.ndims() - 1:
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index f58082e..2c0e194 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -329,14 +329,16 @@ class Simulation:
             DEMSCGrid(self, xmax, ymax, zmax, spacing, diameter, min_diameter, max_diameter,
                       initial_velocity, particle_density, ntypes))
 
-    def build_cell_lists(self, spacing, store_neighbors_per_cell=False):
-        """Add routines to build the linked-cells acceleration structure"""
+    def build_cell_lists(self, spacing=None, store_neighbors_per_cell=False):
+        """Add routines to build the linked-cells acceleration structure.
+        Leave spacing as None so it can be set at runtime."""
         self._store_neighbors_per_cell = store_neighbors_per_cell
         self.cell_lists = CellLists(self, self._dom_part, spacing, spacing)
         return self.cell_lists
 
-    def build_neighbor_lists(self, spacing):
-        """Add routines to build the Verlet Lists acceleration structure"""
+    def build_neighbor_lists(self, spacing=None):
+        """Add routines to build the Verlet Lists acceleration structure.
+        Leave spacing as None so it can be set at runtime."""
 
         assert self._store_neighbors_per_cell is False, \
             "Using neighbor-lists with store_neighbors_per_cell option is invalid."
-- 
GitLab