From 38a012c7405f17f7052b6146e0eeefdbb55206ff Mon Sep 17 00:00:00 2001
From: Behzad Safaei <iwia103h@alex2.nhr.fau.de>
Date: Tue, 28 Jan 2025 12:54:54 +0100
Subject: [PATCH] Add support for custom args in user-defined modules

---
 examples/dem_sd.py            |   7 +-
 examples/main.cpp             | 164 ++++++++++++++++----------
 src/pairs/analysis/devices.py |   6 +-
 src/pairs/analysis/modules.py |   5 +
 src/pairs/code_gen/cgen.py    | 211 +++++++++++++++++++++-------------
 src/pairs/ir/declaration.py   |   2 +-
 src/pairs/ir/device.py        |   3 +
 src/pairs/ir/functions.py     |   8 ++
 src/pairs/ir/kernel.py        |  16 +++
 src/pairs/ir/module.py        |  26 ++++-
 src/pairs/ir/parameters.py    |  18 +++
 src/pairs/mapping/funcs.py    |  12 +-
 src/pairs/sim/simulation.py   |  55 +++++++--
 13 files changed, 370 insertions(+), 163 deletions(-)
 create mode 100644 src/pairs/ir/parameters.py

diff --git a/examples/dem_sd.py b/examples/dem_sd.py
index c0e0212..15ee699 100644
--- a/examples/dem_sd.py
+++ b/examples/dem_sd.py
@@ -165,7 +165,7 @@ psim.setup(update_mass_and_inertia, {'densityParticle_SI': densityParticle_SI,
 
 #psim.compute_half()
 psim.build_cell_lists(linkedCellWidth)
-# psim.vtk_output(f"output/dem_{target}", frequency=visSpacing)
+psim.vtk_output(f"output/dem_{target}", frequency=visSpacing)
 
 psim.compute(gravity,
              symbols={'densityParticle_SI': densityParticle_SI,
@@ -175,11 +175,10 @@ psim.compute(gravity,
 
 psim.compute(linear_spring_dashpot,
              linkedCellWidth,
-             symbols={'dt': dt_SI,
-                      'pi': math.pi,
+             symbols={'pi': math.pi,
                       'kappa': kappa,
                       'lnDryResCoeff': lnDryResCoeff,
                       'collisionTime_SI': collisionTime_SI})
 
-psim.compute(euler, symbols={'dt': dt_SI})
+psim.compute(euler, parameters={'dt' : pairs.real()})
 psim.generate()
diff --git a/examples/main.cpp b/examples/main.cpp
index ea9094f..0bae494 100644
--- a/examples/main.cpp
+++ b/examples/main.cpp
@@ -5,28 +5,29 @@
 #include <blockforest/Initialization.h>
 
 int main(int argc, char **argv) {
+    double dt = 5e-5;
     auto pairs_sim = std::make_shared<PairsSimulation>();
     auto pairs_acc = std::make_shared<PairsAccessor>(pairs_sim);
 
     // Create forest (make sure to use_domain(forest)) ----------------------------------------------
-    walberla::math::AABB domain(0, 0, 0, 0.1, 0.1, 0.1);
-    std::shared_ptr<walberla::mpi::MPIManager> mpiManager = walberla::mpi::MPIManager::instance();
-    mpiManager->initializeMPI(&argc, &argv);
-    mpiManager->useWorldComm();
-    auto rank = mpiManager->rank();
-    auto procs = mpiManager->numProcesses();
-    auto block_config = walberla::Vector3<int>(2, 2, 2);
-    auto ref_level = 0;
-    std::shared_ptr<walberla::BlockForest> forest = walberla::blockforest::createBlockForest(
-            domain, block_config, walberla::Vector3<bool>(false, false, false), procs, ref_level);
+    // walberla::math::AABB domain(0, 0, 0, 0.1, 0.1, 0.1);
+    // std::shared_ptr<walberla::mpi::MPIManager> mpiManager = walberla::mpi::MPIManager::instance();
+    // mpiManager->initializeMPI(&argc, &argv);
+    // mpiManager->useWorldComm();
+    // auto procs = mpiManager->numProcesses();
+    // auto block_config = walberla::Vector3<int>(2, 2, 2);
+    // auto ref_level = 0;
+    // std::shared_ptr<walberla::BlockForest> forest = walberla::blockforest::createBlockForest(
+    //         domain, block_config, walberla::Vector3<bool>(false, false, false), procs, ref_level);
     //-----------------------------------------------------------------------------------------------
 
     // initialize pairs data structures ----------------------------------------------
     pairs_sim->initialize();
 
     // either create new domain or use an existing one ----------------------------------------
-    // pairs_sim->create_domain(argc, argv);
-    pairs_sim->use_domain(forest);
+    pairs_sim->create_domain(argc, argv);
+
+    // pairs_sim->use_domain(forest);
 
     // create planes and particles ------------------------------------------------------------
     pairs_sim->create_halfspace(0,     0,      0,      1, 0, 0,        0, 13);
@@ -36,23 +37,26 @@ int main(int argc, char **argv) {
     pairs_sim->create_halfspace(0.1,   0.1,    0.1,    0, -1, 0,       0, 13);
     pairs_sim->create_halfspace(0.1,   0.1,    0.1,    0, 0, -1,       0, 13);
 
-    pairs::id_t pUid = pairs_sim->create_sphere(0.0499,   0.0499,   0.07,   0.5, 0.5, 0 ,   1000, 0.0045, 0, 0);
-    pairs::id_t pUid2 = pairs_sim->create_sphere(0.0499,   0.0499,   0.0499,   0.5, 0.5, 0 ,   1000, 0.0045, 0, 0);
+    // pairs::id_t pUid = pairs_sim->create_sphere(0.0499,   0.0499,   0.07,   0.5, 0.5, 0 ,   1000, 0.0045, 0, 0);
+    // pairs::id_t pUid2 = pairs_sim->create_sphere(0.0499,   0.0499,   0.0499,   0.5, 0.5, 0 ,   1000, 0.0045, 0, 0);
 
     // Tracking a particle ------------------------------------------------------------------------
     // if (pUid != pairs_acc->getInvalidUid()){
     //     std::cout<< "Particle " << pUid << " is created in rank " << rank << std::endl;
     // }
     
-    walberla::mpi::allReduceInplace(pUid, walberla::mpi::SUM);
-    walberla::mpi::allReduceInplace(pUid2, walberla::mpi::SUM);
+    // walberla::mpi::allReduceInplace(pUid, walberla::mpi::SUM);
+    // walberla::mpi::allReduceInplace(pUid2, walberla::mpi::SUM);
+    // MPI_Allreduce(MPI_IN_PLACE, &pUid, 1, MPI_LONG_LONG_INT, MPI_SUM, MPI_COMM_WORLD);
+    // MPI_Allreduce(MPI_IN_PLACE, &pUid2, 1, MPI_LONG_LONG_INT, MPI_SUM, MPI_COMM_WORLD);
+
 
     // if (pUid != pairs_acc->getInvalidUid()){
     //     std::cout<< "Particle " << pUid << " will be tracked by rank " << rank << std::endl;
     // }
 
-    auto pIsLocalInMyRank = [&](pairs::id_t uid){return pairs_acc->uidToIdx(uid) != pairs_acc->getInvalidIdx();};
-    auto pIsGhostInMyRank = [&](pairs::id_t uid){return pairs_acc->uidToIdxGhost(uid) != pairs_acc->getInvalidIdx();};
+    // auto pIsLocalInMyRank = [&](pairs::id_t uid){return pairs_acc->uidToIdx(uid) != pairs_acc->getInvalidIdx();};
+    // auto pIsGhostInMyRank = [&](pairs::id_t uid){return pairs_acc->uidToIdxGhost(uid) != pairs_acc->getInvalidIdx();};
 
 
     // TODO: make sure linkedCellWidth is larger than max diameter in the system
@@ -74,55 +78,95 @@ int main(int argc, char **argv) {
     //     pairs_acc->setLinearVelocity(idx, walberla::Vector3<double>(0.5, 0.5, 0.5));
     // }
 
-    for (int t=0; t<1; ++t){
+    for (int t=0; t<5000; ++t){
         // if ((t%200==0) && (pIsLocalInMyRank(pUid))){
         //     int idx = pairs_acc->uidToIdx(pUid);
         //     std::cout<< "Tracked particle is now in rank " << rank << " --- " << pairs_acc->getPosition(idx)<< std::endl;
         // }
 
         pairs_sim->communicate(t);
+
+        // if(pairs_sim->rank() == 0){
+        //     pairs_sim->pairs_runtime->copyPropertyToHost(0, ReadOnly, (((pairs_sim->pobj->nlocal + pairs_sim->pobj->nghost) * 1) * sizeof(unsigned long long int))); // uid
+        //     for(int i=0; i<pairs_sim->pobj->nlocal; ++i){
+        //         std::cout << "rank ##0## local uid[" << i << "] = " << pairs_sim->pobj->uid[i] << std::endl;
+        //     }
+        // }
+        // if(pairs_sim->rank() == 2){
+        //     pairs_sim->pairs_runtime->copyPropertyToHost(0, ReadOnly, (((pairs_sim->pobj->nlocal + pairs_sim->pobj->nghost) * 1) * sizeof(unsigned long long int))); // uid
+        //     for(int i=0; i<pairs_sim->pobj->nlocal; ++i){
+        //         std::cout << "rank 2 local uid[" << i << "] = " << pairs_sim->pobj->uid[i] << std::endl;
+        //     }
+        //     for(int i=pairs_sim->pobj->nlocal; i<pairs_sim->pobj->nlocal + pairs_sim->pobj->nghost; ++i){
+        //         std::cout << "rank 2 ghost uid[" << i << "] = " << pairs_sim->pobj->uid[i] << std::endl;
+        //     }
+        // }
+
+        // std::cout << pairs_sim->rank() << " --------- nlocal = " << pairs_acc->nlocal() << " nghost = " << pairs_acc->nghost() << std::endl;
+
+        // if (pairs_sim->rank() == 2) {
+        //     for(int i=0; i<pairs_acc->nlocal(); ++i){
+        //         std::cout << pairs_sim->rank() << " Local  " << i << " -- uid= " << pairs_acc->getUid(i) << std::endl; 
+        //     }
+        //     for(int i=pairs_acc->nlocal(); i<pairs_acc->size(); ++i){
+        //         std::cout << pairs_sim->rank() << " Ghost -- " << i << " -- uid= " << pairs_acc->getUid(i) << std::endl; 
+        //     }
+        // }
+
+
+
+        // if (pIsLocalInMyRank(pUid)){
+        //     int idx = pairs_acc->uidToIdx(pUid);
+        //     pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(1, 1, 1));
+        //     pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(2, 2, 2));
+        // }
+        // if (pIsLocalInMyRank(pUid2)){
+        //     int idx = pairs_acc->uidToIdx(pUid2);
+        //     pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(10, 10, 10));
+        //     pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(20, 20, 20));
+        // }
         
-        if (pIsLocalInMyRank(pUid)){
-            int idx = pairs_acc->uidToIdx(pUid);
-            pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(1, 1, 1));
-            pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(2, 2, 2));
-        }
-        if (pIsLocalInMyRank(pUid2)){
-            int idx = pairs_acc->uidToIdx(pUid2);
-            pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(10, 10, 10));
-            pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(20, 20, 20));
-        }
-
-        pairs_sim->do_timestep(t);
-
-        if (pIsGhostInMyRank(pUid)){
-            int idx = pairs_acc->uidToIdxGhost(pUid);
-            pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(1, 1, 1));
-            pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(3, 3, 3));
-        }
-        if (pIsGhostInMyRank(pUid2)){
-            int idx = pairs_acc->uidToIdxGhost(pUid2);
-            pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(2, 2, 2));
-            pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(4, 4, 4));
-        }
-
-        std::cout << "reverse comm and reduce" << std::endl;
-        pairs_sim->reverse_comm();
-
-        if (pIsLocalInMyRank(pUid)){
-            int idx = pairs_acc->uidToIdx(pUid);
-            auto forceSum = pairs_acc->getHydrodynamicForce(idx);
-            auto torqueSum = pairs_acc->getHydrodynamicTorque(idx);
-            std::cout << pUid << " @@@@ reduced force = " << forceSum << std::endl;
-            std::cout << pUid << " @@@@ reduced torque = " << torqueSum << std::endl;
-        }
-        if (pIsLocalInMyRank(pUid2)){
-            int idx = pairs_acc->uidToIdx(pUid2);
-            auto forceSum = pairs_acc->getHydrodynamicForce(idx);
-            auto torqueSum = pairs_acc->getHydrodynamicTorque(idx);
-            std::cout << pUid2 << " @@@@ reduced force = " << forceSum << std::endl;
-            std::cout << pUid2 << " @@@@ reduced torque = " << torqueSum << std::endl;
-        }
+        pairs_sim->update_cells(); 
+        pairs_sim->gravity(); 
+        pairs_sim->linear_spring_dashpot(); 
+        pairs_sim->euler(dt); 
+
+        // std::cout << "reset_volatiles" << std::endl;
+        pairs_sim->reset_volatiles(); 
+
+        pairs_sim->vtk_write("output/dem_sd_local", 0, pairs_acc->nlocal(), t, 200);
+        pairs_sim->vtk_write("output/dem_sd_ghost", pairs_acc->nlocal(), pairs_acc->size(), t, 200);
+
+        // if (pIsGhostInMyRank(pUid)){
+        //     std::cout << pairs_sim->rank() << " - ghost 1: " << pUid << std::endl;
+        //     int idx = pairs_acc->uidToIdxGhost(pUid);
+        //     pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(1, 1, 1));
+        //     pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(3, 3, 3));
+        // }
+        // if (pIsGhostInMyRank(pUid2)){
+        //     std::cout << pairs_sim->rank() << " - ghost 2: " << pUid2 << std::endl;
+        //     int idx = pairs_acc->uidToIdxGhost(pUid2);
+        //     pairs_acc->setHydrodynamicForce(idx, walberla::Vector3<double>(2, 2, 2));
+        //     pairs_acc->setHydrodynamicTorque(idx, walberla::Vector3<double>(4, 4, 4));
+        // }
+
+        // std::cout << "reverse comm and reduce" << std::endl;
+        // pairs_sim->reverse_comm();
+
+        // if (pIsLocalInMyRank(pUid)){
+        //     int idx = pairs_acc->uidToIdx(pUid);
+        //     auto forceSum = pairs_acc->getHydrodynamicForce(idx);
+        //     auto torqueSum = pairs_acc->getHydrodynamicTorque(idx);
+        //     std::cout << pUid << " @@@@ reduced force = " << forceSum << std::endl;
+        //     std::cout << pUid << " @@@@ reduced torque = " << torqueSum << std::endl;
+        // }
+        // if (pIsLocalInMyRank(pUid2)){
+        //     int idx = pairs_acc->uidToIdx(pUid2);
+        //     auto forceSum = pairs_acc->getHydrodynamicForce(idx);
+        //     auto torqueSum = pairs_acc->getHydrodynamicTorque(idx);
+        //     std::cout << pUid2 << " @@@@ reduced force = " << forceSum << std::endl;
+        //     std::cout << pUid2 << " @@@@ reduced torque = " << torqueSum << std::endl;
+        // }
     }
 
     pairs_sim->end();
diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py
index 23a3bed..1c009af 100644
--- a/src/pairs/analysis/devices.py
+++ b/src/pairs/analysis/devices.py
@@ -196,7 +196,11 @@ class FetchKernelReferences(Visitor):
             # Variables only have a device version when changed within kernels
             if self.writing:
                 ast_node.device_flag = True
-    
+
+    def visit_Parameter(self, ast_node):
+        for k in self.kernel_stack:
+            k.add_parameter(ast_node, self.writing)
+
     def visit_Iter(self, ast_node):
         for k in self.kernel_stack:
             if ast_node.is_ref_candidate():
diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py
index 4cf0bf8..24ab2be 100644
--- a/src/pairs/analysis/modules.py
+++ b/src/pairs/analysis/modules.py
@@ -115,3 +115,8 @@ class FetchModulesReferences(Visitor):
         for m in self.module_stack:
             if not ast_node.temporary():
                 m.add_variable(ast_node, self.writing)
+
+    def visit_Parameter(self, ast_node):
+        for m in self.module_stack:
+            # parameters are restricted to read-only, passed by value
+            m.add_parameter(ast_node, write=False)
\ No newline at end of file
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index c9e2c2e..27093c5 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -28,6 +28,7 @@ from pairs.ir.sizeof import Sizeof
 from pairs.ir.types import Types
 from pairs.ir.print import Print, PrintCode
 from pairs.ir.variables import Var, DeclareVariable, Deref
+from pairs.ir.parameters import Parameter
 from pairs.ir.vectors import Vector, VectorAccess, VectorOp, ZeroVector
 from pairs.sim.domain_partitioners import DomainPartitioners
 from pairs.sim.timestep import Timestep
@@ -175,10 +176,19 @@ class CGen:
         self.print("")
 
         self.print("// Module headers")
-        for module in self.sim.modules():
-            if module.name != "main":
-                self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj);")
 
+        self.print("namespace pairs::internal {")
+        self.print.add_indent(4)
+
+        for module in self.sim.modules():
+            if module.name != "main" and not module.user_defined:
+                module_params = ", ".join(f"{Types.c_keyword(self.sim, param.type())} {param.name()}"
+                                                for param in module.parameters())
+                module_params = ", " + module_params if module_params else ""
+                self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj{module_params});")
+        
+        self.print.add_indent(-4)
+        self.print("}")
         self.print("")
 
     def generate_host_pairs_accessor_class(self):
@@ -370,7 +380,7 @@ class CGen:
 
         self.print.end()
 
-    def generate_library(self, initialize_module, create_domain_module, setup_sim_module,  do_timestep_module, reverse_comm_module, communicate_module, reset_volatiles_module):
+    def generate_library(self, update_cells_module, user_defined_modules, initialize_module, create_domain_module, setup_sim_module,  do_timestep_module, reverse_comm_module, communicate_module, reset_volatiles_module):
         self.generate_interfaces()
         # Generate CUDA/CPP file with modules
         ext = ".cu" if self.target.is_gpu() else ".cpp"
@@ -378,6 +388,7 @@ class CGen:
         self.print.start()
         self.generate_preamble()
         self.print(f"#include \"{self.ref}.hpp\"")
+        self.print("")
 
         if self.target.is_gpu():
             for array in self.sim.arrays.statics():
@@ -393,13 +404,22 @@ class CGen:
                     tkw = Types.c_keyword(self.sim, t)
                     size = feature_prop.array_size()
                     self.print(f"__constant__ {tkw} d_{feature_prop.name()}[{size}];")
+
+        self.print("")
                     
+        self.print("namespace pairs::internal {")
+        self.print.add_indent(4)
+        
         for kernel in self.sim.kernels():
             self.generate_kernel(kernel)
 
         for module in self.sim.modules():
-            if module.name not in ['initialize', 'create_domain', 'setup_sim', 'do_timestep', 'reverse_comm', 'communicate', 'reset_volatiles']:
-                self.generate_module(module)
+            if module.name not in ['update_cells', 'initialize', 'create_domain', 'setup_sim', 'do_timestep', 'reverse_comm', 'communicate', 'reset_volatiles']:
+                if not module.user_defined:
+                    self.generate_module(module)
+
+        self.print.add_indent(-4)
+        self.print("}")
 
         self.print.end()
 
@@ -420,7 +440,7 @@ class CGen:
 
         self.generate_full_object_names = True
         self.print("class PairsSimulation {")
-        self.print("private:")
+        self.print("public:")
         self.print("    PairsRuntime *pairs_runtime;")
         self.print("    struct PairsObjects *pobj;")
         self.print("    friend class PairsAccessor;")
@@ -428,19 +448,19 @@ class CGen:
 
         self.print.add_indent(4)
 
+        for module in user_defined_modules:
+            self.generate_module(module)
+            self.print("")
+        
         self.print("void initialize() {")
         self.print(f"    pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});")
         self.print(f"    pobj = new PairsObjects();")
-        self.print.add_indent(4)
         self.generate_statement(initialize_module.block)
-        self.print.add_indent(-4)
         self.print("}")
         self.print("")
 
         self.print("void create_domain(int argc, char **argv) {")
-        self.print.add_indent(4)
         self.generate_statement(create_domain_module.block)
-        self.print.add_indent(-4)
         self.print("}")
         self.print("")
 
@@ -466,39 +486,43 @@ class CGen:
         self.print("")
 
         self.print("void setup_sim() {")
-        self.print.add_indent(4)
         self.generate_statement(setup_sim_module.block)
-        self.print.add_indent(-4)
         self.print("}")
         self.print("")
 
-        self.print("void do_timestep(int timestep) {")
-        self.print("    pobj->sim_timestep = timestep;")
-        self.print.add_indent(4)
-        self.generate_statement(do_timestep_module.block)
-        self.print.add_indent(-4)
+        self.print("void update_cells() {")
+        self.generate_statement(update_cells_module.block)
         self.print("}")
         self.print("")
 
+        # self.print("void do_timestep(int timestep) {")
+        # self.print("    pobj->sim_timestep = timestep;")
+        # self.print.add_indent(4)
+        # self.generate_statement(do_timestep_module.block)
+        # self.print.add_indent(-4)
+        # self.print("}")
+        # self.print("")
+
+
+
         self.print("void reverse_comm() {")
-        self.print.add_indent(4)
         self.generate_statement(reverse_comm_module.block)
-        self.print.add_indent(-4)
         self.print("}")
         self.print("")
 
         self.print("void communicate(int timestep) {")
         self.print("    pobj->sim_timestep = timestep;")
-        self.print.add_indent(4)
         self.generate_statement(communicate_module.block)
-        self.print.add_indent(-4)
         self.print("}")
         self.print("")
 
         self.print("void reset_volatiles() {")
-        self.print.add_indent(4)
         self.generate_statement(reset_volatiles_module.block)
-        self.print.add_indent(-4)
+        self.print("}")
+        self.print("")
+
+        self.print("void vtk_write(const char* filename, int start, int end, int timestep, int frequency) {")
+        self.print("    pairs::vtk_write_data(pairs_runtime, filename, start, end, timestep, frequency);")
         self.print("}")
         self.print("")
 
@@ -519,6 +543,62 @@ class CGen:
         self.print.end()
         self.generate_full_object_names = False
 
+    def generate_module_declerations(self, module):
+        device_cond = module.run_on_device and self.target.is_gpu()
+
+        for var in module.read_only_variables():
+            type_kw = Types.c_keyword(self.sim, var.type())
+            self.print(f"{type_kw} {var.name()} = pobj->{var.name()};")
+
+        for var in module.write_variables():
+            type_kw = Types.c_keyword(self.sim, var.type())
+
+            if device_cond and var.device_flag:
+                self.print(f"{type_kw} *{var.name()} = pobj->rv_{var.name()}.getDevicePointer();")
+            elif var.force_read:
+                self.print(f"{type_kw} {var.name()} = pobj->{var.name()};")
+            else:
+                self.print(f"{type_kw} *{var.name()} = &(pobj->{var.name()});")
+
+        for array in module.arrays():
+            type_kw = Types.c_keyword(self.sim, array.type())
+            name = array.name() if not device_cond else f"d_{array.name()}"
+            if not array.is_static() or (array.is_static() and not device_cond):
+                self.print(f"{type_kw} *{array.name()} = pobj->{name};")
+
+            if array in module.host_references():
+                self.print(f"{type_kw} *h_{array.name()} = pobj->{array.name()};")
+
+
+        for prop in module.properties():
+            type_kw = Types.c_keyword(self.sim, prop.type())
+            name = prop.name() if not device_cond else f"d_{prop.name()}"
+            self.print(f"{type_kw} *{prop.name()} = pobj->{name};")
+
+            if prop in module.host_references():
+                self.print(f"{type_kw} *h_{prop.name()} = pobj->{prop.name()};")
+
+        for contact_prop in module.contact_properties():
+            type_kw = Types.c_keyword(self.sim, contact_prop.type())
+            name = contact_prop.name() if not device_cond else f"d_{contact_prop.name()}"
+            self.print(f"{type_kw} *{contact_prop.name()} = pobj->{name};")
+
+            if contact_prop in module.host_references():
+                self.print(f"{type_kw} *h_{contact_prop.name()} = pobj->{contact_prop.name()};")
+
+        for feature_prop in module.feature_properties():
+            type_kw = Types.c_keyword(self.sim, feature_prop.type())
+            name = feature_prop.name() if not device_cond else f"d_{feature_prop.name()}"
+
+            if feature_prop.device_flag and device_cond:
+                # self.print(f"{type_kw} *{feature_prop.name()} = {self.generate_object_reference(feature_prop, device=device_cond)};")
+                continue
+            else:
+                self.print(f"{type_kw} *{feature_prop.name()} = pobj->{name};")
+
+            if feature_prop in module.host_references():
+                self.print(f"{type_kw} *h_{feature_prop.name()} = pobj->{feature_prop.name()};")
+
     def generate_module(self, module):
         if module.name == 'main':
             ndims = module.sim.ndims()
@@ -549,65 +629,23 @@ class CGen:
             self.generate_full_object_names = False
 
         else:
-            self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj) {{")
+            module_params = ", ".join(f"{Types.c_keyword(self.sim, param.type())} {param.name()}"
+                                                for param in module.parameters())
+            if not module.user_defined:
+                module_params = ", " + module_params if module_params else ""
+                self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj{module_params}) {{")
+            else:
+                
+                self.print(f"void {module.name}({module_params}) {{")
+
+
             self.print.add_indent(4)
-            device_cond = module.run_on_device and self.target.is_gpu()
 
             if self.debug:
                 self.print(f"PAIRS_DEBUG(\"\\n{module.name}\\n\");")
 
-            for var in module.read_only_variables():
-                type_kw = Types.c_keyword(self.sim, var.type())
-                self.print(f"{type_kw} {var.name()} = pobj->{var.name()};")
-
-            for var in module.write_variables():
-                type_kw = Types.c_keyword(self.sim, var.type())
-
-                if device_cond and var.device_flag:
-                    self.print(f"{type_kw} *{var.name()} = pobj->rv_{var.name()}.getDevicePointer();")
-                elif var.force_read:
-                    self.print(f"{type_kw} {var.name()} = pobj->{var.name()};")
-                else:
-                    self.print(f"{type_kw} *{var.name()} = &(pobj->{var.name()});")
-
-            for array in module.arrays():
-                type_kw = Types.c_keyword(self.sim, array.type())
-                name = array.name() if not device_cond else f"d_{array.name()}"
-                if not array.is_static() or (array.is_static() and not device_cond):
-                    self.print(f"{type_kw} *{array.name()} = pobj->{name};")
-
-                if array in module.host_references():
-                    self.print(f"{type_kw} *h_{array.name()} = pobj->{array.name()};")
-
-
-            for prop in module.properties():
-                type_kw = Types.c_keyword(self.sim, prop.type())
-                name = prop.name() if not device_cond else f"d_{prop.name()}"
-                self.print(f"{type_kw} *{prop.name()} = pobj->{name};")
-
-                if prop in module.host_references():
-                    self.print(f"{type_kw} *h_{prop.name()} = pobj->{prop.name()};")
-
-            for contact_prop in module.contact_properties():
-                type_kw = Types.c_keyword(self.sim, contact_prop.type())
-                name = contact_prop.name() if not device_cond else f"d_{contact_prop.name()}"
-                self.print(f"{type_kw} *{contact_prop.name()} = pobj->{name};")
-
-                if contact_prop in module.host_references():
-                    self.print(f"{type_kw} *h_{contact_prop.name()} = pobj->{contact_prop.name()};")
-
-            for feature_prop in module.feature_properties():
-                type_kw = Types.c_keyword(self.sim, feature_prop.type())
-                name = feature_prop.name() if not device_cond else f"d_{feature_prop.name()}"
-
-                if feature_prop.device_flag and device_cond:
-                    # self.print(f"{type_kw} *{feature_prop.name()} = {self.generate_object_reference(feature_prop, device=device_cond)};")
-                    continue
-                else:
-                    self.print(f"{type_kw} *{feature_prop.name()} = pobj->{name};")
-
-                if feature_prop in module.host_references():
-                    self.print(f"{type_kw} *h_{feature_prop.name()} = pobj->{feature_prop.name()};")
+            if not module.user_defined:
+                self.generate_module_declerations(module)
 
             self.print.add_indent(-4)
             self.generate_statement(module.block)
@@ -616,6 +654,11 @@ class CGen:
     def generate_kernel(self, kernel):
         kernel_params = "int range_start"
         has_resizes = False
+        for param in kernel.parameters():
+            type_kw = Types.c_keyword(self.sim, param.type())
+            decl = f"{type_kw} {param.name()}"
+            kernel_params += f", {decl}"
+
         for var in kernel.read_only_variables():
             type_kw = Types.c_keyword(self.sim, var.type())
             decl = f"{type_kw} {var.name()}"
@@ -981,6 +1024,9 @@ class CGen:
             kernel = ast_node.kernel
             kernel_params = f"{range_start}"
 
+            for param in kernel.parameters():
+                kernel_params += f", {param.name()}"
+
             for var in kernel.read_only_variables():
                 kernel_params += f", {var.name()}"
 
@@ -1022,7 +1068,9 @@ class CGen:
             self.print("}")
 
         if isinstance(ast_node, ModuleCall):
-            self.print(f"{ast_node.module.name}(pairs_runtime, pobj);")
+            module_params = ", ".join(f"{param.name()}" for param in ast_node.module.parameters())
+            module_params = ", " + module_params if module_params else ""
+            self.print(f"pairs::internal::{ast_node.module.name}(pairs_runtime, pobj{module_params});")
 
         if isinstance(ast_node, Print):
             args = ast_node.args
@@ -1333,7 +1381,10 @@ class CGen:
 
         if isinstance(ast_node, Var):
             return self.generate_object_reference(ast_node, index=index)
-
+        
+        if isinstance(ast_node, Parameter):
+            return ast_node.name()
+        
         if isinstance(ast_node, VectorAccess):
             return self.generate_expression(ast_node.expr, mem, self.generate_expression(ast_node.index))
 
diff --git a/src/pairs/ir/declaration.py b/src/pairs/ir/declaration.py
index 3e26c57..35992d5 100644
--- a/src/pairs/ir/declaration.py
+++ b/src/pairs/ir/declaration.py
@@ -7,7 +7,7 @@ class Decl(ASTNode):
         self.elem = elem
 
     def __str__(self):
-        return f"Decl<self.elem>"
+        return f"Decl<{self.elem}>"
 
     def children(self):
         return [self.elem]
diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py
index d7cb0d1..4eddecf 100644
--- a/src/pairs/ir/device.py
+++ b/src/pairs/ir/device.py
@@ -62,6 +62,9 @@ class CopyProperty(ASTNode):
         self._action = action
         self.sim.add_statement(self)
 
+    def __str__(self):
+        return f"CopyProperty<{self._prop}>"
+    
     def prop(self):
         return self._prop
 
diff --git a/src/pairs/ir/functions.py b/src/pairs/ir/functions.py
index b18406b..efaf3e6 100644
--- a/src/pairs/ir/functions.py
+++ b/src/pairs/ir/functions.py
@@ -11,6 +11,9 @@ class Call(ASTTerm):
         self.params = [Lit.cvt(sim, p) for p in params]
         self.return_type = return_type
 
+    def __str__(self):
+        return f"Call<{self.func_name}, {self.params}>"
+    
     def name(self):
         return self.func_name
 
@@ -28,8 +31,13 @@ class Call_Int(Call):
     def __init__(self, sim, func_name, parameters):
         super().__init__(sim, func_name, parameters, Types.Int32)
 
+    def __str__(self):
+        return f"Call_Int<{self.func_name}, {self.params}>"
 
 class Call_Void(Call):
     def __init__(self, sim, func_name, parameters):
         super().__init__(sim, func_name, parameters, Types.Invalid)
         sim.add_statement(self)
+    
+    def __str__(self):
+        return f"Call_Void<{self.func_name}, {self.params}>"
diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py
index 5faaee4..4f477ca 100644
--- a/src/pairs/ir/kernel.py
+++ b/src/pairs/ir/kernel.py
@@ -8,6 +8,7 @@ from pairs.ir.matrices import MatrixOp
 from pairs.ir.properties import Property, ContactProperty
 from pairs.ir.quaternions import QuaternionOp
 from pairs.ir.variables import Var
+from pairs.ir.parameters import Parameter
 from pairs.ir.vectors import VectorOp
 from pairs.ir.loops import Iter
 
@@ -20,6 +21,7 @@ class Kernel(ASTNode):
         self._id = Kernel.last_kernel
         self._name = name if name is not None else "kernel" + str(Kernel.last_kernel)
         self._variables = {}
+        self._parameters = {}
         self._iters = {}
         self._arrays = {}
         self._properties = {}
@@ -52,6 +54,9 @@ class Kernel(ASTNode):
     def variables(self):
         return self._variables
 
+    def parameters(self):
+        return self._parameters
+    
     def iters(self):
         return self._iters
     
@@ -105,6 +110,17 @@ class Kernel(ASTNode):
                 action = Actions.NoAction if var not in self._variables else self._variables[var]
                 self._variables[var] = Actions.update_rule(action, new_op)
     
+    def add_parameter(self, parameter, write=False):
+        parameter_list = parameter if isinstance(parameter, list) else [parameter]
+        new_op = 'w' if write else 'r'
+        # Record a write ('w') or read ('r') access for one Parameter or a list of them.
+        for param in parameter_list:
+            assert isinstance(param, Parameter), \
+                "Kernel.add_parameter(): given element is not of type Parameter!"
+
+            action = Actions.NoAction if param not in self._parameters else self._parameters[param]
+            self._parameters[param] = Actions.update_rule(action, new_op)
+
     def add_iter(self, iter, write=False):
         iter_list = iter if isinstance(iter, list) else [iter]
         new_op = 'w' if write else 'r'
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
index ab78942..52ee117 100644
--- a/src/pairs/ir/module.py
+++ b/src/pairs/ir/module.py
@@ -4,15 +4,17 @@ from pairs.ir.ast_node import ASTNode
 from pairs.ir.features import FeatureProperty
 from pairs.ir.properties import Property, ContactProperty
 from pairs.ir.variables import Var
+from pairs.ir.parameters import Parameter
 
 
 class Module(ASTNode):
     last_module = 0
 
-    def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False):
+    def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False, user_defined=False):
         super().__init__(sim)
         self._id = Module.last_module
         self._name = name if name is not None else "module" + str(Module.last_module)
+        self._parameters = {}
         self._variables = {}
         self._arrays = {}
         self._properties = {}
@@ -23,6 +25,7 @@ class Module(ASTNode):
         self._resizes_to_check = resizes_to_check
         self._check_properties_resize = check_properties_resize
         self._run_on_device = run_on_device
+        self._user_defined = user_defined
         self._profile = False
         sim.add_module(self)
         Module.last_module += 1
@@ -45,6 +48,10 @@ class Module(ASTNode):
     @property
     def run_on_device(self):
         return self._run_on_device
+    
+    @property
+    def user_defined(self):
+        return self._user_defined
 
     def profile(self):
         self._profile = True
@@ -53,6 +60,9 @@ class Module(ASTNode):
     def must_profile(self):
         return self._profile
 
+    def parameters(self):
+        return self._parameters
+    
     def variables(self):
         return self._variables
 
@@ -99,6 +109,17 @@ class Module(ASTNode):
             action = Actions.NoAction if var not in self._variables else self._variables[var]
             self._variables[var] = Actions.update_rule(action, new_op)
 
+    def add_parameter(self, parameter, write=False):
+        parameter_list = parameter if isinstance(parameter, list) else [parameter]
+        new_op = 'w' if write else 'r'
+        # Accepts one Parameter or a list; each gets its stored action combined with new_op.
+        for param in parameter_list:
+            assert isinstance(param, Parameter), \
+                "Module.add_parameter(): given element is not of type Parameter!"
+
+            action = Actions.NoAction if param not in self._parameters else self._parameters[param]
+            self._parameters[param] = Actions.update_rule(action, new_op)
+
     def add_property(self, prop, write=False):
         prop_list = prop if isinstance(prop, list) else [prop]
         new_op = 'w' if write else 'r'
@@ -150,5 +171,8 @@ class ModuleCall(ASTNode):
     def module(self):
         return self._module
 
+    def __str__(self):
+        return f"ModuleCall<{self._module}>"
+    
     def children(self):
         return [self._module]
diff --git a/src/pairs/ir/parameters.py b/src/pairs/ir/parameters.py
new file mode 100644
index 0000000..d6b9ab6
--- /dev/null
+++ b/src/pairs/ir/parameters.py
@@ -0,0 +1,18 @@
+from pairs.ir.ast_term import ASTTerm 
+from pairs.ir.operator_class import OperatorClass
+
+
+class Parameter(ASTTerm):  # AST term identified by a name and a type
+    def __init__(self, sim, param_name, param_type):
+        super().__init__(sim, OperatorClass.from_type(param_type))  # operator class is derived from param_type
+        self.param_name = param_name
+        self.param_type = param_type
+
+    def __str__(self):
+        return f"Parameter<{self.param_name}>"
+
+    def name(self):  # returns the name passed to the constructor
+        return self.param_name
+
+    def type(self):  # returns the type passed to the constructor
+        return self.param_type
diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py
index 0632bd2..efcb825 100644
--- a/src/pairs/mapping/funcs.py
+++ b/src/pairs/mapping/funcs.py
@@ -7,6 +7,7 @@ from pairs.ir.loops import For, ParticleFor
 from pairs.ir.operators import Operators
 from pairs.ir.operator_class import OperatorClass
 from pairs.ir.properties import ContactProperty
+from pairs.ir.parameters import Parameter
 from pairs.ir.scalars import ScalarOp
 from pairs.ir.types import Types
 from pairs.mapping.keywords import Keywords
@@ -80,9 +81,10 @@ class BuildParticleIR(ast.NodeVisitor):
 
         raise Exception("Invalid operator: {}".format(ast.dump(op)))
 
-    def __init__(self, sim, ctx_symbols={}):
+    def __init__(self, sim, ctx_symbols={}, func_params={}):
         self.sim = sim
         self.ctx_symbols = ctx_symbols.copy()
+        self.func_params = func_params.copy()
         self.keywords = Keywords(sim)
 
     def add_symbols(self, symbols):
@@ -210,6 +212,7 @@ class BuildParticleIR(ast.NodeVisitor):
     def visit_Name(self, node):
         symbol_types = [
             self.ctx_symbols.get,
+            self.func_params.get,
             self.sim.array,
             self.sim.property,
             self.sim.feature_property,
@@ -282,7 +285,7 @@ class BuildParticleIR(ast.NodeVisitor):
         return op_class(self.sim, operand, None, BuildParticleIR.get_unary_op(node.op))
 
 
-def compute(sim, func, cutoff_radius=None, symbols={}, pre_step=False, skip_first=False):
+def compute(sim, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=False, skip_first=False):
     src = inspect.getsource(func)
     tree = ast.parse(src, mode='exec')
     #print(ast.dump(ast.parse(src, mode='exec')))
@@ -298,6 +301,7 @@ def compute(sim, func, cutoff_radius=None, symbols={}, pre_step=False, skip_firs
 
     # Convert literal symbols
     symbols = {symbol: Lit.cvt(sim, value) for symbol, value in symbols.items()}
+    parameters = {pname: Parameter(sim, pname, ptype) for pname, ptype in parameters.items()}
 
     sim.init_block()
     sim.module_name(func.__name__)
@@ -305,14 +309,14 @@ def compute(sim, func, cutoff_radius=None, symbols={}, pre_step=False, skip_firs
     if nparams == 1:
         for i in ParticleFor(sim):
             for _ in Filter(sim, ScalarOp.cmp(sim.particle_flags[i] & Flags.Fixed, 0)):
-                ir = BuildParticleIR(sim, symbols)
+                ir = BuildParticleIR(sim, symbols, parameters)
                 ir.add_symbols({params[0]: i})
                 ir.visit(tree)
 
     else:
         for interaction_data in ParticleInteraction(sim, nparams, cutoff_radius):
             # Start building IR
-            ir = BuildParticleIR(sim, symbols)
+            ir = BuildParticleIR(sim, symbols, parameters)
             ir.add_symbols({
                 params[0]: interaction_data.i(),
                 params[1]: interaction_data.j(),
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 044a135..5c665e4 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -93,11 +93,13 @@ class Simulation:
         # Different segments of particle code/functions
         self.create_domain = Block(self, [])
         self.setup_particles = Block(self, [])
+        self.module_list = []
+        self.kernel_list = []
+
+        # User-defined modules
         self.setup_functions = []
         self.pre_step_functions = []
         self.functions = []
-        self.module_list = []
-        self.kernel_list = []
 
         # Structures to generated resize code for capacities
         self._check_properties_resize = False
@@ -303,8 +305,8 @@ class Simulation:
         self.neighbor_lists = NeighborLists(self, self.cell_lists)
         return self.neighbor_lists
 
-    def compute(self, func, cutoff_radius=None, symbols={}, pre_step=False, skip_first=False):
-        return compute(self, func, cutoff_radius, symbols, pre_step, skip_first)
+    def compute(self, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=False, skip_first=False):
+        return compute(self, func, cutoff_radius, symbols, parameters, pre_step, skip_first)
 
     def setup(self, func, symbols={}):
         return setup(self, func, symbols)
@@ -447,8 +449,10 @@ class Simulation:
         # Params that determine when a method must be called only when reneighboring
         every_reneighbor_params = {'every': self.reneighbor_frequency}
 
+        timestep_procedures = []
+
         # First steps executed during each time-step in the simulation
-        timestep_procedures = self.pre_step_functions 
+        timestep_procedures += self.pre_step_functions 
 
         comm_routine = [
             (comm.exchange(), every_reneighbor_params),
@@ -458,21 +462,23 @@ class Simulation:
         if self._generate_whole_program:
             timestep_procedures += comm_routine
 
-        timestep_procedures +=    [
+        update_cells =    [
             (BuildCellLists(self, self.cell_lists), every_reneighbor_params),
             (PartitionCellLists(self, self.cell_lists), every_reneighbor_params)
         ]
 
         # Add routine to build neighbor-lists per cell
         if self._store_neighbors_per_cell:
-            timestep_procedures.append(
+            update_cells.append(
                 (BuildCellNeighborLists(self, self.cell_lists), every_reneighbor_params))
 
         # Add routine to build neighbor-lists per particle (standard Verlet Lists)
         if self.neighbor_lists is not None:
-            timestep_procedures.append(
+            update_cells.append(
                 (BuildNeighborLists(self, self.neighbor_lists), every_reneighbor_params))
 
+        timestep_procedures += update_cells
+
         # Add routines for contact history management
         if self._use_contact_history:
             if self.neighbor_lists is not None:
@@ -548,6 +554,9 @@ class Simulation:
 
         # Generate a small library to be called
         else:
+            update_cells = [m[0] if isinstance(m, tuple) else m for m in update_cells]
+            update_cells_module = Module(self, name='update_cells', block=Block.from_list(self, update_cells))
+
             initialize_module = Module(self, name='initialize', block=inits)
             create_domain_module = Module(self, name='create_domain', block=self.create_domain)
 
@@ -563,12 +572,34 @@ class Simulation:
             communicate_module = Module(self, name='communicate', block=Timestep(self, 0, comm_routine).as_block())
             reset_volatiles_module = Module(self, name='reset_volatiles', block=Block(self, ResetVolatileProperties(self)))
 
-            modules_list = [initialize_module, create_domain_module, setup_sim_module, do_timestep_module, reverse_comm_module, communicate_module, reset_volatiles_module]
+            modules_list = [
+                update_cells_module, 
+                initialize_module, 
+                create_domain_module, 
+                setup_sim_module, 
+                reverse_comm_module, 
+                communicate_module, 
+                reset_volatiles_module
+            ]
 
-            transformations = Transformations(modules_list, self._target)
-            transformations.apply_all()
+            Transformations(modules_list, self._target).apply_all()
+
+            # User-defined modules are transformed separately as individual modules,
+            # i.e. they are transformed once again if already transformed in setup_sim or do_timestep
+            user_defined_modules = self.setup_functions + self.pre_step_functions + self.functions
+            user_defined_modules = [m[0] if isinstance(m, tuple) else m for m in user_defined_modules]
+            user_defined_modules = [Module(self, name=m.name, block=Block(self, m), user_defined=True) for m in user_defined_modules]
+            Transformations(user_defined_modules, self._target).apply_all()
 
             # Generate library
-            self.code_gen.generate_library(initialize_module, create_domain_module, setup_sim_module, do_timestep_module, reverse_comm_module, communicate_module, reset_volatiles_module)
+            self.code_gen.generate_library(update_cells_module, 
+                                           user_defined_modules, 
+                                           initialize_module, 
+                                           create_domain_module, 
+                                           setup_sim_module, 
+                                           do_timestep_module, 
+                                           reverse_comm_module, 
+                                           communicate_module, 
+                                           reset_volatiles_module)
 
         self.code_gen.generate_interfaces()
-- 
GitLab