diff --git a/examples/modular/force_reduction.cpp b/examples/modular/force_reduction.cpp
index b34a3451184aa1ab68ba3c613b38e9c065ac25df..e5d2394d108a3068ec86ec8b8026d3ae95c8daca 100644
--- a/examples/modular/force_reduction.cpp
+++ b/examples/modular/force_reduction.cpp
@@ -16,6 +16,7 @@ int main(int argc, char **argv) {
     
     // setup_sim after creating all bodies
     pairs_sim->setup_sim();
+    pairs_sim->update_mass_and_inertia();
 
     // Track particle
     //-------------------------------------------------------------------------------------------
@@ -31,7 +32,7 @@ int main(int argc, char **argv) {
 
     // Communicate particles (exchange/ghost)
     //-------------------------------------------------------------------------------------------
-    pairs_sim->communicate();
+    pairs_sim->communicate(0);
     ac->update();
         
     // Helper lambdas for demo
@@ -89,7 +90,7 @@ int main(int argc, char **argv) {
         
         // Do computations
         //-------------------------------------------------------------------------------------------
-        pairs_sim->update_cells(); 
+        pairs_sim->update_cells(t); 
         pairs_sim->gravity(); 
         pairs_sim->spring_dashpot();
         pairs_sim->euler(5e-5); 
diff --git a/examples/modular/force_reduction.py b/examples/modular/force_reduction.py
index 50846cd0795bb7dc77b021528ebc1d62bdebf17d..af9cea7058190b7fda485cc1ec3a6fd6cde8b4f1 100644
--- a/examples/modular/force_reduction.py
+++ b/examples/modular/force_reduction.py
@@ -102,7 +102,9 @@ psim.set_domain_partitioner(pairs.block_forest())
 psim.pbc([False, False, False])
 psim.build_cell_lists(linkedCellWidth)
 
-psim.setup(update_mass_and_inertia, symbols={'infinity': math.inf })
+# The order of user-defined functions is not important here since 
+# they are not used by other subroutines and are only callable individually 
+psim.compute(update_mass_and_inertia, symbols={'infinity': math.inf })
 psim.compute(spring_dashpot, linkedCellWidth)
 psim.compute(gravity, symbols={'gravity_SI': gravity_SI })
 psim.compute(euler, parameters={'dt': pairs.real()})
diff --git a/examples/modular/sd_sample_1_CPU_GPU.cpp b/examples/modular/sd_sample_1_CPU_GPU.cpp
index 03b835c7a61accb652f8ee70f875bd64ba47d757..960ffbaf7dec899ea3c054b0ea8d3191569a440e 100644
--- a/examples/modular/sd_sample_1_CPU_GPU.cpp
+++ b/examples/modular/sd_sample_1_CPU_GPU.cpp
@@ -20,6 +20,7 @@ int main(int argc, char **argv) {
     pairs_sim->create_sphere(0.4, 0.4, 0.68,    2, 2, 0,    1000, 0.05, 0, 0);
 
     pairs_sim->setup_sim();
+    pairs_sim->update_mass_and_inertia();
 
     int num_timesteps = 2000;
     int vtk_freq = 20;
@@ -30,7 +31,7 @@ int main(int argc, char **argv) {
 
         pairs_sim->communicate(t);
         
-        pairs_sim->update_cells(); 
+        pairs_sim->update_cells(t); 
 
         pairs_sim->gravity(); 
         pairs_sim->spring_dashpot(); 
diff --git a/examples/modular/sd_sample_2_CPU_GPU.cpp b/examples/modular/sd_sample_2_CPU_GPU.cpp
index 330003fe478d75fc273c5426cd8f9504d5be1ff6..e18ce39c3d01a5f7156d727d26a17f9080e7ca22 100644
--- a/examples/modular/sd_sample_2_CPU_GPU.cpp
+++ b/examples/modular/sd_sample_2_CPU_GPU.cpp
@@ -42,6 +42,7 @@ int main(int argc, char **argv) {
     pairs_sim->create_sphere(0.4, 0.4, 0.68,    2, 2, 0,    1000, 0.05, 0, 0);
 
     pairs_sim->setup_sim();
+    pairs_sim->update_mass_and_inertia();
 
     int num_timesteps = 2000;
     int vtk_freq = 20;
@@ -52,7 +53,7 @@ int main(int argc, char **argv) {
 
         pairs_sim->communicate(t);
         
-        pairs_sim->update_cells(); 
+        pairs_sim->update_cells(t); 
 
         pairs_sim->gravity(); 
         pairs_sim->spring_dashpot(); 
diff --git a/examples/modular/sd_sample_3_CPU.cpp b/examples/modular/sd_sample_3_CPU.cpp
index 315a7a2a8b67023ed2a5440df9f301d2a6c5059c..002e4f907b4c77b403dc0b640dcfde93511514bf 100644
--- a/examples/modular/sd_sample_3_CPU.cpp
+++ b/examples/modular/sd_sample_3_CPU.cpp
@@ -32,8 +32,9 @@ int main(int argc, char **argv) {
     auto pIsLocalInMyRank = [&](pairs::id_t uid){return ac->uidToIdxLocal(uid) != ac->getInvalidIdx();};
 
     pairs_sim->setup_sim();
+    pairs_sim->update_mass_and_inertia();
 
-    pairs_sim->communicate();
+    pairs_sim->communicate(0);
 
     int num_timesteps = 2000;
     int vtk_freq = 20;
@@ -55,7 +56,7 @@ int main(int argc, char **argv) {
 
         // Calculate forces
         //-------------------------------------------------------------------------------------------
-        pairs_sim->update_cells();
+        pairs_sim->update_cells(t);
         pairs_sim->gravity(); 
         pairs_sim->spring_dashpot(); 
 
diff --git a/examples/modular/sd_sample_3_GPU.cu b/examples/modular/sd_sample_3_GPU.cu
index 6861c080761e3ecbb20c5a75e3e5da8f88135b58..e18bbf0705542f05bac1ab28be06e16fa2527f5b 100644
--- a/examples/modular/sd_sample_3_GPU.cu
+++ b/examples/modular/sd_sample_3_GPU.cu
@@ -65,8 +65,9 @@ int main(int argc, char **argv) {
     auto pIsLocalInMyRank = [&](pairs::id_t uid){return ac->uidToIdxLocal(uid) != ac->getInvalidIdx();};
 
     pairs_sim->setup_sim();
+    pairs_sim->update_mass_and_inertia();
 
-    pairs_sim->communicate();
+    pairs_sim->communicate(0);
     // PairsAccessor requires an update when particles are communicated 
     ac->update();
 
@@ -85,15 +86,15 @@ int main(int argc, char **argv) {
             int idx = ac->uidToIdxLocal(pUid);
 
             // Up-to-date position might be on host or device. 
-            // Sync position on HostAndDevice before printing it from host and device:
-            ac->syncPosition(PairsAccessor::HostAndDevice); 
-
+            // Sync position on Host before reading it from host:
+            ac->syncPosition(PairsAccessor::Host); 
             std::cout << "Position [from host] = (" 
                     << ac->getPosition(idx)[0] << ", "
                     << ac->getPosition(idx)[1] << ", " 
                     << ac->getPosition(idx)[2] << ")" << std::endl;
             
-            // Position is synced on both host and device. Print position from device:
+            // Sync position on Device before reading it from device:
+            ac->syncPosition(PairsAccessor::Device); 
             print_position<<<1,1>>>(*ac, idx);
             checkCudaError(cudaDeviceSynchronize(), "print_position");
             
@@ -102,13 +103,12 @@ int main(int argc, char **argv) {
 
         // Calculate forces
         //-------------------------------------------------------------------------------------------
-        pairs_sim->update_cells();
+        pairs_sim->update_cells(t);
         pairs_sim->gravity(); 
         pairs_sim->spring_dashpot(); 
 
         // Change gravitational force on particle pUid
         //-------------------------------------------------------------------------------------------
-        // Here we are syncing Uid on Host again for clarity, but no data transfer will happen since Uid is already on host
         ac->syncUid(PairsAccessor::Host);
 
         if(pIsLocalInMyRank(pUid)){
@@ -125,8 +125,8 @@ int main(int argc, char **argv) {
             checkCudaError(cudaDeviceSynchronize(), "change_gravitational_force");
 
             // Force on device was modified.
-            // So sync force before continuing the simulation. By default (no args), force is synced on both host and device
-            ac->syncForce();
+            // So sync force before continuing the simulation.
+            ac->syncForce(PairsAccessor::Host);
             std::cout << "Force [from host] after changing = (" 
                     << ac->getForce(idx)[0] << ", "
                     << ac->getForce(idx)[1] << ", " 
diff --git a/examples/modular/spring_dashpot.py b/examples/modular/spring_dashpot.py
index 4374e789347d29ba0d520987d9d625f08070a3fd..1fb833da2032b2ccf08f200a037d2f9ee8dcbc96 100644
--- a/examples/modular/spring_dashpot.py
+++ b/examples/modular/spring_dashpot.py
@@ -98,9 +98,9 @@ psim.set_domain_partitioner(pairs.block_forest())
 psim.pbc([False, False, False])
 psim.build_cell_lists(linkedCellWidth)
 
-psim.setup(update_mass_and_inertia, symbols={'infinity': math.inf })
-
-# The order of user-defined functions is not important here since they are only callable individually
+# The order of user-defined functions is not important here since 
+# they are not used by other subroutines and are only callable individually 
+psim.compute(update_mass_and_inertia, symbols={'infinity': math.inf })
 psim.compute(spring_dashpot, linkedCellWidth)
 psim.compute(euler, parameters={'dt': pairs.real()})
 psim.compute(gravity, symbols={'gravity_SI': gravity_SI })
diff --git a/examples/whole-program-generation/spring_dashpot.py b/examples/whole-program-generation/spring_dashpot.py
index f4796cfebf670efa8e9725e2daf52e2d94942c6b..397b5421bc418fc687226ac22b5e803f3566527b 100644
--- a/examples/whole-program-generation/spring_dashpot.py
+++ b/examples/whole-program-generation/spring_dashpot.py
@@ -142,9 +142,10 @@ psim.read_particle_data( "data/sd_planes.input", ['type', 'mass', 'position', 'n
 
 psim.vtk_output(f"output/dem_{target}", frequency=visSpacing)
 
+# The user-defined setup functions are executed only once before the timestep loop
 psim.setup(update_mass_and_inertia, symbols={'infinity': math.inf })
 
-# The user-defined functions are added to the timestep loop in the order they are given to 'compute'
+# The user-defined compute functions are added to the timestep loop in the order they are given to 'compute'
 psim.compute(spring_dashpot, linkedCellWidth)
 psim.compute(gravity, symbols={'gravity_SI': gravity_SI })
 psim.compute(euler, symbols={'dt': dt_SI})
diff --git a/src/pairs/code_gen/accessor.py b/src/pairs/code_gen/accessor.py
index 41b8bd099aa7c5e85ce28f546091a923dcfd911b..eb2a73dde25ec1e47b2a2ded20b3068de98ade09 100644
--- a/src/pairs/code_gen/accessor.py
+++ b/src/pairs/code_gen/accessor.py
@@ -42,7 +42,7 @@ class PairsAcessor:
         self.sync_ctx_enum()
         self.update()
         self.constructor()
-        self.destructor()
+        # self.destructor()
 
         for p in self.sim.properties:
             if (p.type()==Types.Vector) or (Types.is_scalar(p.type())):
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 313101da28f8fae613fdab6a7448355fc2381997..82b082faf522c2269745c6f87ce8a84814946469 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -149,7 +149,8 @@ class CGen:
 
         if self.target.is_gpu():
             self.print("#define PAIRS_TARGET_CUDA")
-
+            self.print("#include <math_constants.h>")
+             
         if self.target.is_openmp():
             self.print("#define PAIRS_TARGET_OPENMP")
             self.print("#include <omp.h>")
@@ -176,10 +177,14 @@ class CGen:
 
     def generate_module_header(self, module, definition=True):
         module_params = []
-        if not module.user_defined and not module.interface:
+
+        if not module.interface:
             module_params += ["PairsRuntime *pairs_runtime", "struct PairsObjects *pobj"]
 
-        if module.name=="init_domain" or module.name=="set_domain":
+        if module.name=="initialize" and self.sim.create_domain_at_initialization:
+            module_params += ["int argc", "char **argv"]
+
+        if module.name=="set_domain":
             module_params += ["int argc", "char **argv"]
 
         module_params += [f"{Types.c_keyword(self.sim, param.type())} {param.name()}" for param in module.parameters()]
@@ -194,7 +199,9 @@ class CGen:
         self.print("namespace pairs::internal {")
         self.print.add_indent(4)
 
-        for module in self.sim.modules():
+        # All modules except the interface ones are declared in the pairs::internal scope
+        for module in self.sim.modules() + self.sim.udf_modules():
+            assert not module.interface
             self.generate_module_header(module, definition=False)
         
         self.print.add_indent(-4)
@@ -307,16 +314,7 @@ class CGen:
 
         self.print.end()
 
-    def generate_library(self, 
-                         update_cells_module, 
-                         user_defined_modules, 
-                         initialize_module,
-                         set_domain_module, 
-                         setup_sim_module, 
-                         reverse_comm_module, 
-                         communicate_module, 
-                         reset_volatiles_module):
-        
+    def generate_library(self):
         self.generate_interfaces()
         # Generate CUDA/CPP file with modules
         ext = ".cu" if self.target.is_gpu() else ".cpp"
@@ -345,14 +343,14 @@ class CGen:
                     
         self.print("namespace pairs::internal {")
         self.print.add_indent(4)
-        
+
         for kernel in self.sim.kernels():
             self.generate_kernel(kernel)
 
-        for module in self.sim.modules():
-            if module.name not in ['update_cells', 'initialize', 'set_domain', 'setup_sim', 'reverse_comm', 'communicate', 'reset_volatiles']:
-                if not module.user_defined:
-                    self.generate_module(module)
+        # All modules except the interface ones are defined in the pairs::internal scope
+        for module in self.sim.modules() + self.sim.udf_modules():
+            assert not module.interface
+            self.generate_module(module)
 
         self.print.add_indent(-4)
         self.print("}")
@@ -368,106 +366,33 @@ class CGen:
         self.generate_pairs_object_structure()
         self.generate_module_decls()
 
-        ndims = module.sim.ndims()
-        nprops = module.sim.properties.nprops()
-        ncontactprops = module.sim.contact_properties.nprops()
-        narrays = module.sim.arrays.narrays()
-        part = DomainPartitioners.c_keyword(module.sim.partitioner())
-
         self.generate_full_object_names = True
         self.print("class PairsSimulation {")
         self.print("private:")
         self.print("    PairsRuntime *pairs_runtime;")
         self.print("    struct PairsObjects *pobj;")
         self.print("    friend class PairsAccessor;")
+        self.print("")
         self.print("public:")
-
         self.print.add_indent(4)
 
-        for module in user_defined_modules:
+        # Only interface modules are generated in the PairsSimulation class
+        for module in self.sim.interface_modules():
             self.generate_module(module)
-            self.print("")
-        
-        if set_domain_module:
-            self.print("void initialize() {")
-        else:
-            self.print("void initialize(int argc, char **argv) {")
-        self.print(f"    pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});")
-        self.print(f"    pobj = new PairsObjects();")
-        self.generate_statement(initialize_module.block)
-        self.print("}")
-        self.print("")
 
-        if set_domain_module:
-            self.generate_module(set_domain_module)
+        # Generate a 'use_domain' module only if domain is not predefined in the input script
+        if not self.sim.create_domain_at_initialization:
+            self.print("template<typename Domain_T>")
+            self.print("void use_domain(const std::shared_ptr<Domain_T> &domain_ptr) {")
+            self.print("    pairs_runtime->useDomain(domain_ptr);")
+            self.print("}")
             self.print("")
 
-        self.print("template<typename Domain_T>")
-        self.print("void use_domain(const std::shared_ptr<Domain_T> &domain_ptr) {")
-        self.print("    pairs_runtime->useDomain(domain_ptr);")
-        self.print("}")
-        self.print("")
-
-        self.print("pairs::id_t create_halfspace(double x, double y, double z, double nx, double ny, double nz, int type, int flag){")
-        self.print("    return pairs::create_halfspace(pairs_runtime, x, y, z, nx, ny, nz, type, flag);")
-        self.print("}")
-        self.print("")
-
-        self.print("pairs::id_t create_sphere(double x, double y, double z, double vx, double vy, double vz, double density, double radius, int type, int flag){")
-        self.print("    return pairs::create_sphere(pairs_runtime, x, y, z, vx, vy, vz, density, radius, type, flag);")
-        self.print("}")
-        self.print("")
-
-        self.print("int rank(){ return pairs_runtime->getDomainPartitioner()->getRank();}")
-        self.print("")
-
-        self.print("int size(){ return pobj->nlocal + pobj->nghost;}")
-        self.print("")
-
-        self.print("int nlocal(){ return pobj->nlocal;}")
-        self.print("")
-
-        self.print("int nghost(){ return pobj->nghost;}")
-        self.print("")
-
-        self.print("void setup_sim() {")
-        self.generate_statement(setup_sim_module.block)
-        self.print("}")
-        self.print("")
-
-        self.print("void update_cells() {")
-        self.generate_statement(update_cells_module.block)
-        self.print("}")
-        self.print("")
-
-        self.print("void reverse_comm() {")
-        self.generate_statement(reverse_comm_module.block)
-        self.print("}")
-        self.print("")
-
-        self.print("void communicate(int timestep = 0) {")
-        self.print("    pobj->sim_timestep = timestep;")
-        self.generate_statement(communicate_module.block)
-        self.print("}")
-        self.print("")
-
-        self.print("void reset_volatiles() {")
-        self.generate_statement(reset_volatiles_module.block)
-        self.print("}")
-        self.print("")
-
         self.print("void vtk_write(const char* filename, int start, int end, int timestep, int frequency) {")
         self.print("    pairs::vtk_write_data(pairs_runtime, filename, start, end, timestep, frequency);")
         self.print("}")
         self.print("")
 
-        self.print("void end() {")
-        self.print("    pairs::print_timers(pairs_runtime);")
-        self.print("    pairs::print_stats(pairs_runtime, pobj->nlocal, pobj->nghost);")
-        self.print("    delete pobj;")
-        self.print("    delete pairs_runtime;")
-        self.print("}")
-
         self.print.add_indent(-4)
         self.print("};")
 
@@ -569,12 +494,13 @@ class CGen:
         if self.debug:
             self.print(f"PAIRS_DEBUG(\"\\n{module.name}\\n\");")
 
-        if not module.user_defined and not module.interface:
+        if not module.interface:
             self.generate_module_declerations(module)
 
         self.print.add_indent(-4)
         self.generate_statement(module.block)
         self.print("}")
+        self.print("")
 
     def generate_kernel(self, kernel):
         kernel_params = "int range_start"
@@ -1228,7 +1154,10 @@ class CGen:
                 return ast_node.value[index]
 
             if isinstance(ast_node.value, float) and math.isinf(ast_node.value):
-                return f"std::numeric_limits<{self.real_type()}>::infinity()"
+                if self.kernel_context:
+                    return "CUDART_INF"
+                else:
+                    return f"std::numeric_limits<{self.real_type()}>::infinity()"
 
             return ast_node.value
 
diff --git a/src/pairs/code_gen/interface.py b/src/pairs/code_gen/interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e52deed7d6386fe4cd054188e4868773a014e6c
--- /dev/null
+++ b/src/pairs/code_gen/interface.py
@@ -0,0 +1,251 @@
+from pairs.ir.block import Block, pairs_interface_block
+from pairs.ir.functions import Call_Void, Call, Call_Int
+from pairs.ir.parameters import Parameter
+from pairs.ir.ret import Return
+from pairs.ir.scalars import ScalarOp
+from pairs.sim.domain import UpdateDomain, SetDomain
+from pairs.sim.cell_lists import BuildCellListsStencil
+from pairs.sim.comm import Synchronize, Borders, Exchange, ReverseComm
+from pairs.ir.types import Types
+from pairs.ir.branches import Filter, Branch
+from pairs.sim.cell_lists import BuildCellLists, BuildCellListsStencil, PartitionCellLists, BuildCellNeighborLists
+from pairs.sim.neighbor_lists import BuildNeighborLists
+from pairs.sim.variables import DeclareVariables 
+from pairs.sim.arrays import DeclareArrays
+from pairs.sim.properties import AllocateProperties, AllocateContactProperties, ResetVolatileProperties
+from pairs.sim.features import AllocateFeatureProperties
+from pairs.sim.instrumentation import RegisterMarkers, RegisterTimers
+from pairs.sim.grid import MutableGrid
+from pairs.sim.domain_partitioners import DomainPartitioners
+from pairs.ir.print import PrintCode
+from pairs.ir.assign import Assign
+from pairs.sim.contact_history import BuildContactHistory, ClearUnusedContactHistory, ResetContactHistoryUsageStatus
+from pairs.sim.thermo import ComputeThermo
+
+class InterfaceModules:
+    def __init__(self, sim):
+        self.sim = sim
+
+    def create_all(self):
+        self.initialize()
+
+        # Generate a 'set_domain' module only if domain is not pre-set in the input script
+        if not self.sim.create_domain_at_initialization:
+            self.set_domain()
+
+        self.setup_sim()
+        self.update_cells(self.sim.reneighbor_frequency) 
+        self.communicate(self.sim.reneighbor_frequency)
+        self.reverse_comm() 
+        self.reset_volatiles()
+
+        if self.sim._use_contact_history:
+            if self.neighbor_lists:
+                self.build_contact_history(self.sim.reneighbor_frequency)
+            self.reset_contact_history()
+
+        if self.sim._compute_thermo != 0:
+            self.compute_thermo(self.sim._compute_thermo)
+
+        self.rank()
+        self.nlocal()
+        self.nghost()
+        self.size()
+        self.create_sphere()
+        self.create_halfspace()
+        self.dem_sc_grid()
+        self.end()      
+
+    @pairs_interface_block
+    def initialize(self):
+        self.sim.module_name('initialize')
+        nprops = self.sim.properties.nprops()
+        ncontactprops = self.sim.contact_properties.nprops()
+        narrays = self.sim.arrays.narrays()
+        part = DomainPartitioners.c_keyword(self.sim.partitioner())
+
+        PrintCode(self.sim, f"pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});")
+        PrintCode(self.sim, f"pobj = new PairsObjects();")
+
+        inits = Block.from_list(self.sim, [
+            DeclareVariables(self.sim),
+            DeclareArrays(self.sim),
+            AllocateProperties(self.sim),
+            AllocateContactProperties(self.sim),
+            AllocateFeatureProperties(self.sim),
+            RegisterTimers(self.sim),
+            RegisterMarkers(self.sim)
+        ])
+
+        if self.sim.create_domain_at_initialization:
+            self.sim.add_statement(Block.merge_blocks(inits, self.sim.create_domain))
+        else:
+            assert self.sim.grid is None, "A grid already exists"
+            self.sim.grid = MutableGrid(self.sim, self.sim.dims)
+            self.sim.add_statement(inits)
+
+    @pairs_interface_block
+    def set_domain(self):
+        assert isinstance(self.sim.grid, MutableGrid)
+        self.sim.module_name('set_domain')
+        self.sim.add_statement(SetDomain(self.sim))
+
+    @pairs_interface_block
+    def setup_sim(self):
+        self.sim.module_name('setup_sim')
+        self.sim.add_statement(self.sim.setup_particles)
+        self.sim.add_statement(UpdateDomain(self.sim))
+        self.sim.add_statement(BuildCellListsStencil(self.sim, self.sim.cell_lists))
+    
+    @pairs_interface_block
+    def reset_volatiles(self):
+        self.sim.module_name('reset_volatiles')
+        self.sim.add_statement(ResetVolatileProperties(self.sim))
+    
+    @pairs_interface_block
+    def update_cells(self, reneighbor_frequency=1):
+        self.sim.module_name('update_cells')
+        timestep = Parameter(self.sim, f'timestep', Types.Int32)
+        cond = ScalarOp.inline(ScalarOp.or_op(
+            ScalarOp.cmp((timestep + 1) % reneighbor_frequency, 0),
+            ScalarOp.cmp(timestep, 0)
+            ))
+        
+        subroutines = [BuildCellLists(self.sim, self.sim.cell_lists),
+                  PartitionCellLists(self.sim, self.sim.cell_lists)]
+        
+        # Add routine to build neighbor-lists per cell
+        if self.sim._store_neighbors_per_cell:
+            subroutines.append(BuildCellNeighborLists(self.sim, self.sim.cell_lists))
+
+        # Add routine to build neighbor-lists per particle (standard Verlet Lists)
+        if self.sim.neighbor_lists:
+            subroutines.append(BuildNeighborLists(self.sim, self.sim.neighbor_lists))
+
+        self.sim.add_statement(Filter(self.sim, cond, Block.from_list(self.sim, subroutines)))
+
+    @pairs_interface_block
+    def communicate(self, reneighbor_frequency=1):
+        self.sim.module_name('communicate')
+        timestep = Parameter(self.sim, f'timestep', Types.Int32)
+        cond = ScalarOp.inline(ScalarOp.or_op(
+            ScalarOp.cmp((timestep + 1) % reneighbor_frequency, 0),
+            ScalarOp.cmp(timestep, 0)
+            ))
+        
+        exchange = Filter(self.sim, cond, Exchange(self.sim._comm))
+        border_sync = Branch(self.sim, cond, blk_if = Borders(self.sim._comm), 
+                             blk_else = Synchronize(self.sim._comm))
+        
+        self.sim.add_statement(exchange)
+        self.sim.add_statement(border_sync)
+
+    @pairs_interface_block
+    def reverse_comm(self):
+        self.sim.module_name('reverse_comm')
+        self.sim.add_statement(ReverseComm(self.sim._comm, reduce=True))
+    
+    @pairs_interface_block
+    def build_contact_history(self, reneighbor_frequency=1):
+        self.sim.module_name('build_contact_history')
+        timestep = Parameter(self.sim, f'timestep', Types.Int32)
+        cond = ScalarOp.inline(ScalarOp.or_op(
+            ScalarOp.cmp((timestep + 1) % reneighbor_frequency, 0),
+            ScalarOp.cmp(timestep, 0)
+            ))
+        
+        self.sim.add_statement(
+            Filter(self.sim, cond,
+                   BuildContactHistory(self.sim, self.sim._contact_history, self.sim.cell_lists)))
+
+    @pairs_interface_block
+    def reset_contact_history(self):
+        self.sim.module_name('reset_contact_history')
+        self.sim.add_statement(ResetContactHistoryUsageStatus(self.sim, self.sim._contact_history))
+        self.sim.add_statement(ClearUnusedContactHistory(self.sim, self.sim._contact_history))
+
+    @pairs_interface_block
+    def compute_thermo(self):
+        self.sim.module_name('compute_thermo')
+        self.sim.add_statement(ComputeThermo(self.sim))
+
+    @pairs_interface_block
+    def rank(self):
+        self.sim.module_name('rank')
+        Return(self.sim, self.sim.domain_partitioning().rank)
+
+    @pairs_interface_block
+    def nlocal(self):
+        self.sim.module_name('nlocal')
+        Return(self.sim, self.sim.nlocal)
+
+    @pairs_interface_block
+    def nghost(self):
+        self.sim.module_name('nghost')
+        Return(self.sim, self.sim.nghost)
+
+    @pairs_interface_block
+    def size(self):
+        self.sim.module_name('size')
+        Return(self.sim, ScalarOp.inline(self.sim.nlocal + self.sim.nghost))
+
+    @pairs_interface_block
+    def create_sphere(self):
+        self.sim.module_name('create_sphere')
+        x = Parameter(self.sim, 'x', Types.Real)
+        y = Parameter(self.sim, 'y', Types.Real)
+        z = Parameter(self.sim, 'z', Types.Real)
+        vx = Parameter(self.sim, 'vx', Types.Real)
+        vy = Parameter(self.sim, 'vy', Types.Real)
+        vz = Parameter(self.sim, 'vz', Types.Real)
+        density = Parameter(self.sim, 'density', Types.Real)
+        radius = Parameter(self.sim, 'radius', Types.Real)
+        ptype = Parameter(self.sim, 'type', Types.Real)
+        flag = Parameter(self.sim, 'flag', Types.Real)
+
+        Return(self.sim, Call(self.sim, "pairs::create_sphere", 
+                              [x, y, z, vx, vy, vz, 
+                               density, radius, ptype, flag], Types.UInt64))
+
+    @pairs_interface_block
+    def create_halfspace(self):
+        self.sim.module_name('create_halfspace')
+        x = Parameter(self.sim, 'x', Types.Real)
+        y = Parameter(self.sim, 'y', Types.Real)
+        z = Parameter(self.sim, 'z', Types.Real)
+        nx = Parameter(self.sim, 'nx', Types.Real)
+        ny = Parameter(self.sim, 'ny', Types.Real)
+        nz = Parameter(self.sim, 'nz', Types.Real)
+        ptype = Parameter(self.sim, 'type', Types.Real)
+        flag = Parameter(self.sim, 'flag', Types.Real)
+
+        Return(self.sim, Call(self.sim, "pairs::create_halfspace", 
+                              [x, y, z, nx, ny, nz, ptype, flag], Types.UInt64))
+        
+    @pairs_interface_block
+    def dem_sc_grid(self):
+        self.sim.module_name('dem_sc_grid')
+        xmax = Parameter(self.sim, 'xmax', Types.Real)
+        ymax = Parameter(self.sim, 'ymax', Types.Real)
+        zmax = Parameter(self.sim, 'zmax', Types.Real)
+        spacing = Parameter(self.sim, 'spacing', Types.Real)
+        diameter = Parameter(self.sim, 'diameter', Types.Real)
+        min_diameter = Parameter(self.sim, 'min_diameter', Types.Real)
+        max_diameter = Parameter(self.sim, 'max_diameter', Types.Real)
+        initial_velocity = Parameter(self.sim, 'initial_velocity', Types.Real)
+        particle_density = Parameter(self.sim, 'particle_density', Types.Real)
+        ntypes = Parameter(self.sim, 'ntypes', Types.Int32)
+
+        Assign(self.sim, self.sim.nlocal,
+               Call_Int(self.sim, "pairs::dem_sc_grid",
+                        [xmax, ymax, zmax, spacing, diameter, min_diameter, max_diameter,
+                         initial_velocity, particle_density, ntypes]))
+        Return(self.sim, self.sim.nlocal)
+
+    @pairs_interface_block
+    def end(self):
+        self.sim.module_name('end')
+        Call_Void(self.sim, "pairs::print_timers", [])
+        Call_Void(self.sim, "pairs::print_stats", [self.sim.nlocal, self.sim.nghost])
+        PrintCode(self.sim, "delete pobj;")
+        PrintCode(self.sim, "delete pairs_runtime;")
diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py
index 1a0809b811840a6f0d0edbccfd1fb68a9712992f..2a14ea2776dcf3651a8f9d03b397bf9db1bd1fb6 100644
--- a/src/pairs/ir/block.py
+++ b/src/pairs/ir/block.py
@@ -42,6 +42,21 @@ def pairs_device_block(func):
     return inner
 
 
+def pairs_interface_block(func):
+    def inner(*args, **kwargs):
+        sim = args[0].sim # self.sim
+        sim.init_block()
+        func(*args, **kwargs)
+        return Module(sim,
+            name=sim._module_name,
+            block=Block(sim, sim._block),
+            resizes_to_check=sim._resizes_to_check,
+            check_properties_resize=sim._check_properties_resize,
+            run_on_device=False,
+            interface=True)
+
+    return inner
+
 class Block(ASTNode):
     def __init__(self, sim, stmts):
         super().__init__(sim)
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
index 44539dba0eab376716cfd4920818889b81e85690..ded67ac6f4448590c346f51177b17fe364b729d0 100644
--- a/src/pairs/ir/module.py
+++ b/src/pairs/ir/module.py
@@ -37,7 +37,17 @@ class Module(ASTNode):
         self._interface = interface
         self._return_type = Types.Void
         self._profile = False
-        sim.add_module(self)
+
+        if user_defined:
+            assert not interface, ("User-defined modules can't be part of the interface directly."
+                                "Wrap them inside seperate interface modules.")
+            sim.add_udf_module(self)
+        else:
+            if interface:
+                sim.add_interface_module(self)
+            else:
+                sim.add_module(self)
+                
         Module.last_module += 1
 
     def __str__(self):
diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py
index 16abbab7d5d9181184914440c91bd1e86e3932a6..b4ec2f8f5fcb083a92f3eb5a43f6f33cc49093a5 100644
--- a/src/pairs/mapping/funcs.py
+++ b/src/pairs/mapping/funcs.py
@@ -335,12 +335,13 @@ def compute(sim, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=F
 
             ir.visit(tree)
 
-    if pre_step:
-        sim.build_pre_step_module_with_statements(skip_first=skip_first, profile=True)
-
+    if sim._generate_whole_program:
+        if pre_step:
+            sim.build_pre_step_module_with_statements(skip_first=skip_first, profile=True)
+        else:
+            sim.build_module_with_statements(skip_first=skip_first, profile=True)
     else:
-        sim.build_module_with_statements(skip_first=skip_first, profile=True)
-
+        sim.build_user_defined_function()
 
 def setup(sim, func, symbols={}):
     src = inspect.getsource(func)
@@ -366,4 +367,8 @@ def setup(sim, func, symbols={}):
         ir.add_symbols({params[0]: i})
         ir.visit(tree)
 
-    sim.build_setup_module_with_statements()
+    if sim._generate_whole_program:
+        sim.build_setup_module_with_statements()
+    else:
+        sim.build_user_defined_function()
+    
diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index 05081631d02825a11da2a417b64cb586d543f6aa..6ec01799f248461c11927d3e8ddcbc81ffe1ac18 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -49,47 +49,31 @@ class Comm:
             self.recv_offsets_reverse     = sim.add_array('recv_offsets_reverse', [dom_part.nranks_capacity], Types.Int32)
             self.recv_buffer_reverse      = sim.add_array('recv_buffer_reverse', [self.recv_capacity, self.elem_capacity], Types.Real, arr_sync=False)
 
+
+class Synchronize(Lowerable):
+    def __init__(self, comm):
+        self.sim = comm.sim
+        self.comm = comm
+
     @pairs_inline
-    def synchronize(self):
+    def lower(self):
         # Every property that is not constant across timesteps and have neighbor accesses during any
         # interaction kernel (i.e. property[j] in force calculation kernel)
         prop_names = ['position', 'linear_velocity', 'angular_velocity']
         prop_list = [self.sim.property(p) for p in prop_names if self.sim.property(p) is not None]
 
-        PackAllGhostParticles(self, prop_list)
-        CommunicateAllData(self, prop_list)
-        UnpackAllGhostParticles(self, prop_list)
-        
-    @pairs_host_block
-    def reverse_comm(self, reduce=False):
-        self.sim.module_name(f"reverse_comm")
-        prop_list = self.sim.properties.reduction_props()
+        PackAllGhostParticles(self.comm, prop_list)
+        CommunicateAllData(self.comm, prop_list)
+        UnpackAllGhostParticles(self.comm, prop_list)
 
-        if prop_list :
-            for step in range(self.dom_part.number_of_steps() - 1, -1, -1):
-                if self.sim._target.is_gpu():
-                    CopyArray(self.sim, self.nsend, Contexts.Host, Actions.ReadOnly)
-                    CopyArray(self.sim, self.nrecv, Contexts.Host, Actions.ReadOnly)
-                    CopyArray(self.sim, self.send_offsets, Contexts.Host, Actions.ReadOnly)
-                    CopyArray(self.sim, self.recv_offsets, Contexts.Host, Actions.ReadOnly)
-
-                    CopyArray(self.sim, self.nsend_reverse, Contexts.Host, Actions.WriteOnly)
-                    CopyArray(self.sim, self.nrecv_reverse, Contexts.Host, Actions.WriteOnly)
-                    CopyArray(self.sim, self.send_offsets_reverse, Contexts.Host, Actions.WriteOnly)
-                    CopyArray(self.sim, self.recv_offsets_reverse, Contexts.Host, Actions.WriteOnly)
-                
-                for j in self.dom_part.step_indexes(step):
-                    Assign(self.sim, self.nsend_reverse[j], self.nrecv[j])
-                    Assign(self.sim, self.nrecv_reverse[j], self.nsend[j])
-                    Assign(self.sim, self.send_offsets_reverse[j], self.recv_offsets[j])
-                    Assign(self.sim, self.recv_offsets_reverse[j], self.send_offsets[j])
 
-                PackGhostParticlesReverse(self, step, prop_list)
-                CommunicateDataReverse(self, step, prop_list)
-                UnpackGhostParticlesReverse(self, step, prop_list, reduce)
+class Borders(Lowerable):
+    def __init__(self, comm):
+        self.sim = comm.sim
+        self.comm = comm
 
     @pairs_inline
-    def borders(self):
+    def lower(self):
         # Every property that has neighbor accesses during any interaction kernel (i.e. property[j]
         # exists in any force calculation kernel)
         # We ignore normal because there should be no ghost half-spaces
@@ -106,84 +90,128 @@ class Comm:
 
         prop_list = [self.sim.property(p) for p in prop_names if self.sim.property(p) is not None]
 
-        Assign(self.sim, self.nsend_all, 0)
+        Assign(self.sim, self.comm.nsend_all, 0)
         Assign(self.sim, self.sim.nghost, 0)
 
-        for step in range(self.dom_part.number_of_steps()):
+        for step in range(self.comm.dom_part.number_of_steps()):
             if self.sim._target.is_gpu():
-                CopyArray(self.sim, self.nsend, Contexts.Host, Actions.Ignore)
-                CopyArray(self.sim, self.nrecv, Contexts.Host, Actions.Ignore)
+                CopyArray(self.sim, self.comm.nsend, Contexts.Host, Actions.Ignore)
+                CopyArray(self.sim, self.comm.nrecv, Contexts.Host, Actions.Ignore)
 
-            for j in self.dom_part.step_indexes(step):
-                Assign(self.sim, self.nsend[j], 0)
-                Assign(self.sim, self.nrecv[j], 0)
+            for j in self.comm.dom_part.step_indexes(step):
+                Assign(self.sim, self.comm.nsend[j], 0)
+                Assign(self.sim, self.comm.nrecv[j], 0)
 
             if self.sim._target.is_gpu():
-                CopyArray(self.sim, self.nsend, Contexts.Device, Actions.Ignore)
-                CopyArray(self.sim, self.nrecv, Contexts.Device, Actions.Ignore)
+                CopyArray(self.sim, self.comm.nsend, Contexts.Device, Actions.Ignore)
+                CopyArray(self.sim, self.comm.nrecv, Contexts.Device, Actions.Ignore)
 
-            DetermineGhostParticles(self, step, self.sim.cell_spacing())
-            CommunicateSizes(self, step)
-            SetCommunicationOffsets(self, step)
-            PackGhostParticles(self, step, prop_list)
-            CommunicateData(self, step, prop_list)
-            UnpackGhostParticles(self, step, prop_list)
+            DetermineGhostParticles(self.comm, step, self.sim.cell_spacing())
+            CommunicateSizes(self.comm, step)
+            SetCommunicationOffsets(self.comm, step)
+            PackGhostParticles(self.comm, step, prop_list)
+            CommunicateData(self.comm, step, prop_list)
+            UnpackGhostParticles(self.comm, step, prop_list)
 
-            step_nrecv = self.dom_part.reduce_sum_step_indexes(step, self.nrecv)
+            step_nrecv = self.comm.dom_part.reduce_sum_step_indexes(step, self.comm.nrecv)
             Assign(self.sim, self.sim.nghost, self.sim.nghost + step_nrecv)
 
+
+class Exchange(Lowerable):
+    def __init__(self, comm):
+        self.sim = comm.sim
+        self.comm = comm
+
     @pairs_inline
-    def exchange(self):
+    def lower(self):
         # Every property except volatiles
         prop_list = self.sim.properties.non_volatiles()
 
-        for step in range(self.dom_part.number_of_steps()):
-            Assign(self.sim, self.nsend_all, 0)
-            Assign(self.sim, self.sim.nghost, 0)
+        for step in range(self.comm.dom_part.number_of_steps()):
+            Assign(self.comm.sim, self.comm.nsend_all, 0)
+            Assign(self.comm.sim, self.sim.nghost, 0)
 
             for s in range(step + 1):
-                for j in self.dom_part.step_indexes(s):
-                    Assign(self.sim, self.nsend[j], 0)
-                    Assign(self.sim, self.nrecv[j], 0)
-                    Assign(self.sim, self.send_offsets[j], 0)
-                    Assign(self.sim, self.recv_offsets[j], 0)
-                    Assign(self.sim, self.nsend_contact[j], 0)
-                    Assign(self.sim, self.nrecv_contact[j], 0)
-                    Assign(self.sim, self.contact_soffsets[j], 0)
-                    Assign(self.sim, self.contact_soffsets[j], 0)
+                for j in self.comm.dom_part.step_indexes(s):
+                    Assign(self.comm.sim, self.comm.nsend[j], 0)
+                    Assign(self.comm.sim, self.comm.nrecv[j], 0)
+                    Assign(self.comm.sim, self.comm.send_offsets[j], 0)
+                    Assign(self.comm.sim, self.comm.recv_offsets[j], 0)
+                    Assign(self.comm.sim, self.comm.nsend_contact[j], 0)
+                    Assign(self.comm.sim, self.comm.nrecv_contact[j], 0)
+                    Assign(self.comm.sim, self.comm.contact_soffsets[j], 0)
+                    Assign(self.comm.sim, self.comm.contact_soffsets[j], 0)
 
             if self.sim._target.is_gpu():
-                CopyArray(self.sim, self.nsend, Contexts.Device, Actions.Ignore)
-                CopyArray(self.sim, self.nrecv, Contexts.Device, Actions.Ignore)
+                CopyArray(self.comm.sim, self.comm.nsend, Contexts.Device, Actions.Ignore)
+                CopyArray(self.comm.sim, self.comm.nrecv, Contexts.Device, Actions.Ignore)
 
-            DetermineGhostParticles(self, step, 0.0)
-            CommunicateSizes(self, step)
-            SetCommunicationOffsets(self, step)
-            PackGhostParticles(self, step, prop_list)
+            DetermineGhostParticles(self.comm, step, 0.0)
+            CommunicateSizes(self.comm, step)
+            SetCommunicationOffsets(self.comm, step)
+            PackGhostParticles(self.comm, step, prop_list)
 
             if self.sim._target.is_gpu():
-                send_map_size = self.nsend_all * Sizeof(self.sim, Types.Int32)
-                exchg_flag_size = self.sim.nlocal * Sizeof(self.sim, Types.Int32)
-                CopyArray(self.sim, self.send_map, Contexts.Host, Actions.ReadOnly, send_map_size)
-                CopyArray(self.sim, self.exchg_flag, Contexts.Host, Actions.ReadOnly, exchg_flag_size)
+                send_map_size = self.comm.nsend_all * Sizeof(self.comm.sim, Types.Int32)
+                exchg_flag_size = self.sim.nlocal * Sizeof(self.comm.sim, Types.Int32)
+                CopyArray(self.comm.sim, self.comm.send_map, Contexts.Host, Actions.ReadOnly, send_map_size)
+                CopyArray(self.comm.sim, self.comm.exchg_flag, Contexts.Host, Actions.ReadOnly, exchg_flag_size)
 
-            RemoveExchangedParticles_part1(self)
+            RemoveExchangedParticles_part1(self.comm)
 
             if self.sim._target.is_gpu():
-                exchg_copy_to_size = self.nsend_all * Sizeof(self.sim, Types.Int32)
+                exchg_copy_to_size = self.comm.nsend_all * Sizeof(self.comm.sim, Types.Int32)
                 CopyArray(
-                    self.sim, self.exchg_copy_to, Contexts.Device, Actions.ReadOnly, exchg_copy_to_size)
+                    self.comm.sim, self.comm.exchg_copy_to, Contexts.Device, Actions.ReadOnly, exchg_copy_to_size)
 
-            RemoveExchangedParticles_part2(self, prop_list)
-            CommunicateData(self, step, prop_list)
-            UnpackGhostParticles(self, step, prop_list)
+            RemoveExchangedParticles_part2(self.comm, prop_list)
+            CommunicateData(self.comm, step, prop_list)
+            UnpackGhostParticles(self.comm, step, prop_list)
 
             if self.sim._use_contact_history:
-                PackContactHistoryData(self, step)
-                CommunicateContactHistoryData(self, step)
-                UnpackContactHistoryData(self, step)
+                PackContactHistoryData(self.comm, step)
+                CommunicateContactHistoryData(self.comm, step)
+                UnpackContactHistoryData(self.comm, step)
+
+            ChangeSizeAfterExchange(self.comm, step)
+
+
+class ReverseComm(Lowerable):
+    def __init__(self, comm, reduce=False):
+        self.sim = comm.sim
+        self.comm = comm
+        self.reduce = reduce
+
+    @pairs_inline
+    def lower(self):
+        prop_list = self.sim.properties.reduction_props()
+
+        if prop_list :
+            for step in range(self.comm.dom_part.number_of_steps() - 1, -1, -1):
+                if self.sim._target.is_gpu():
+                    CopyArray(self.sim, self.comm.nsend, Contexts.Host, Actions.ReadOnly)
+                    CopyArray(self.sim, self.comm.nrecv, Contexts.Host, Actions.ReadOnly)
+                    CopyArray(self.sim, self.comm.send_offsets, Contexts.Host, Actions.ReadOnly)
+                    CopyArray(self.sim, self.comm.recv_offsets, Contexts.Host, Actions.ReadOnly)
+
+                    CopyArray(self.sim, self.comm.nsend_reverse, Contexts.Host, Actions.WriteOnly)
+                    CopyArray(self.sim, self.comm.nrecv_reverse, Contexts.Host, Actions.WriteOnly)
+                    CopyArray(self.sim, self.comm.send_offsets_reverse, Contexts.Host, Actions.WriteOnly)
+                    CopyArray(self.sim, self.comm.recv_offsets_reverse, Contexts.Host, Actions.WriteOnly)
+                
+                for j in self.comm.dom_part.step_indexes(step):
+                    Assign(self.sim, self.comm.nsend_reverse[j], self.comm.nrecv[j])
+                    Assign(self.sim, self.comm.nrecv_reverse[j], self.comm.nsend[j])
+                    Assign(self.sim, self.comm.send_offsets_reverse[j], self.comm.recv_offsets[j])
+                    Assign(self.sim, self.comm.recv_offsets_reverse[j], self.comm.send_offsets[j])
+
+                PackGhostParticlesReverse(self.comm, step, prop_list)
+                CommunicateDataReverse(self.comm, step, prop_list)
+                UnpackGhostParticlesReverse(self.comm, step, prop_list, self.reduce)
+
+
+
 
-            ChangeSizeAfterExchange(self, step)
 
 
 class CommunicateSizes(Lowerable):
diff --git a/src/pairs/sim/domain.py b/src/pairs/sim/domain.py
index 559384024a6999221814f086714dde8bca114374..9a85792a7b31412172540ed96152b51b175b3fb0 100644
--- a/src/pairs/sim/domain.py
+++ b/src/pairs/sim/domain.py
@@ -1,4 +1,4 @@
-from pairs.ir.block import pairs_inline, pairs_host_block
+from pairs.ir.block import pairs_inline
 from pairs.ir.parameters import Parameter
 from pairs.ir.types import Types
 from pairs.ir.assign import Assign
@@ -17,10 +17,8 @@ class SetDomain(Lowerable):
     def __init__(self, sim):
         super().__init__(sim)
 
-    @pairs_host_block
+    @pairs_inline
     def lower(self):
-        self.sim.module_name('init_domain')
-
         for d in range(self.sim.ndims()):
             dmin = Parameter(self.sim, f'd{d}_min', Types.Real)
             Assign(self.sim, self.sim.grid.min(d), dmin)
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 89cfb5925224de67d0a0580238f6bdf2a9b785a2..f58082ea8d592351f4833d82594b0423b2906401 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -13,7 +13,7 @@ from pairs.ir.variables import Variables
 from pairs.mapping.funcs import compute, setup
 from pairs.sim.arrays import DeclareArrays
 from pairs.sim.cell_lists import CellLists, BuildCellLists, BuildCellListsStencil, PartitionCellLists, BuildCellNeighborLists
-from pairs.sim.comm import Comm
+from pairs.sim.comm import Comm, Synchronize, Borders, Exchange, ReverseComm
 from pairs.sim.contact_history import ContactHistory, BuildContactHistory, ClearUnusedContactHistory, ResetContactHistoryUsageStatus
 from pairs.sim.copper_fcc_lattice import CopperFCCLattice
 from pairs.sim.dem_sc_grid import DEMSCGrid
@@ -32,6 +32,7 @@ from pairs.sim.timestep import Timestep
 from pairs.sim.variables import DeclareVariables 
 from pairs.sim.vtk import VTKWrite
 from pairs.transformations import Transformations
+from pairs.code_gen.interface import InterfaceModules
 
 
 class Simulation:
@@ -92,13 +93,17 @@ class Simulation:
 
         # Different segments of particle code/functions
         self.create_domain = Block(self, [])
-        self.generate_set_domain_module = True
+        self.create_domain_at_initialization = False
 
         self.setup_particles = Block(self, [])
         self.module_list = []
         self.kernel_list = []
 
-        # User-defined modules
+        # Individual user-defined and interface modules are created only when generate_whole_program is False
+        self.udf_module_list = []
+        self.interface_module_list = []
+
+        # User-defined functions to be called by other subroutines (used only when generate_whole_program is True)
         self.setup_functions = []
         self.pre_step_functions = []
         self.functions = []
@@ -114,6 +119,7 @@ class Simulation:
         # Domain partitioning
         self._dom_part = None
         self._partitioner = None
+        self._comm = None
 
         # Contact history
         self._use_contact_history = use_contact_history
@@ -164,11 +170,30 @@ class Simulation:
     def max_shapes(self):
         return len(self._shapes)
 
+    def add_udf_module(self, module):
+        assert isinstance(module, Module), "add_udf_module(): Given parameter is not of type Module!"
+        assert module.user_defined and not module.interface
+        if module.name not in [m.name for m in self.udf_module_list]:
+            self.udf_module_list.append(module)
+
+    def add_interface_module(self, module):
+        assert isinstance(module, Module), "add_interface_module(): Given parameter is not of type Module!"
+        assert module.interface and not module.user_defined
+        if module.name not in [m.name for m in self.interface_module_list]:
+            self.interface_module_list.append(module)
+
     def add_module(self, module):
         assert isinstance(module, Module), "add_module(): Given parameter is not of type Module!"
+        assert not module.interface and not module.user_defined
         if module.name not in [m.name for m in self.module_list]:
             self.module_list.append(module)
 
+    def interface_modules(self):
+        return self.interface_module_list
+    
+    def udf_modules(self):
+        return self.udf_module_list
+    
     def modules(self):
         """List simulation modules, with main always in the last position"""
 
@@ -279,7 +304,7 @@ class Simulation:
         If the domain is set through this function, the 'set_domain' module won't be generated in the modular version.
         Use this function only if you do not need to set domain at runtime.
         This function is required only for whole-program generation."""
-        self.generate_set_domain_module = False
+        self.create_domain_at_initialization = True
         self.grid = Grid3D(self, grid[0], grid[1], grid[2], grid[3], grid[4], grid[5])
         self.create_domain.add_statement(InitializeDomain(self))
 
@@ -357,7 +382,7 @@ class Simulation:
                 block=Block(self, self._block),
                 resizes_to_check=self._resizes_to_check,
                 check_properties_resize=self._check_properties_resize,
-                run_on_device=False))
+                run_on_device=True))
 
     def build_pre_step_module_with_statements(self, run_on_device=True, skip_first=False, profile=False):
         """Build a Module in the pre-step part of the program using the last initialized block"""
@@ -392,6 +417,16 @@ class Simulation:
         else:
             self.functions.append(module)
 
+    def build_user_defined_function(self, run_on_device=True):
+        """Build a user-defined Module that will be callable seperately as part of the interface"""
+        Module(self, name=self._module_name,
+                block=Block(self, self._block),
+                resizes_to_check=self._resizes_to_check,
+                check_properties_resize=self._check_properties_resize,
+                run_on_device=run_on_device,
+                user_defined=True)
+        
+
     def capture_statements(self, capture=True):
         """When toggled, all constructed statements are captured and automatically added to the last initialized block"""
         self._capture_statements = capture
@@ -454,12 +489,37 @@ class Simulation:
 
     def generate(self):
         """Generate the code for the simulation"""
-
         assert self._target is not None, "Target not specified!"
 
-        # Initialize communication instance with specified domain-partitioner
-        comm = Comm(self, self._dom_part)
-        reverse_comm_module = comm.reverse_comm(reduce=True)
+        # Initialize communication instance with the specified domain-partitioner
+        self._comm = Comm(self, self._dom_part)
+        
+        if self._generate_whole_program:
+            self.generate_program()
+        else:
+            self.generate_library()
+
+    def generate_library(self):
+        InterfaceModules(self).create_all()
+        
+        # User defined functions are wrapped inside seperate interface modules here.
+        # The udf's have the same name as their interface module but they get implemented in the pairs::internal scope.
+        for m in self.udf_module_list:
+            module = Module(self, name=m.name, block=Block(self, m), interface=True)
+            module._id = m._id
+
+        Transformations(self.interface_modules(), self._target).apply_all()
+
+        # Generate library
+        self.code_gen.generate_library()
+
+        # Generate getters for the runtime functions
+        self.code_gen.generate_interfaces()
+
+    def generate_program(self):
+        assert self.grid, "No domain is created. Set domain bounds with 'set_domain'."
+
+        reverse_comm_module = ReverseComm(self._comm, reduce=True)
 
         # Params that determine when a method must be called only when reneighboring
         every_reneighbor_params = {'every': self.reneighbor_frequency}
@@ -470,8 +530,8 @@ class Simulation:
         timestep_procedures += self.pre_step_functions 
 
         comm_routine = [
-            (comm.exchange(), every_reneighbor_params),
-            (comm.borders(), comm.synchronize(), every_reneighbor_params)
+            (Exchange(self._comm), every_reneighbor_params),
+            (Borders(self._comm), Synchronize(self._comm), every_reneighbor_params)
             ]
         
         if self._generate_whole_program:
@@ -546,86 +606,25 @@ class Simulation:
         self.leave()
 
         # Combine everything into a whole program
-        if self._generate_whole_program:
-            assert self.grid, "No domain is created. Set domain bounds with 'set_domain'."
-
-            # Initialization and setup functions, together with time-step loop
-            # UpdateDomain is added after setup_particles because particles must be already present in the simulation
-            body = Block.from_list(self, [
-                self.create_domain,
-                self.setup_particles,
-                UpdateDomain(self),        
-                self.setup_functions,
-                BuildCellListsStencil(self, self.cell_lists),
-                timestep.as_block()
-            ])
-
-            program = Module(self, name='main', block=Block.merge_blocks(inits, body))
-
-            # Apply transformations
-            transformations = Transformations(program, self._target)
-            transformations.apply_all()
-
-            # Generate program
-            self.code_gen.generate_program(program)
-
-        # Generate a small library to be called
-        else:
-            update_cells = [m[0] if isinstance(m, tuple) else m for m in update_cells]
-            update_cells_module = Module(self, name='update_cells', block=Block.from_list(self, update_cells))
-
-            # Either generate a set_domain module, or create domain during initialization 
-            if self.generate_set_domain_module:
-                self.grid = MutableGrid(self, self.dims)
-                initialize_module = Module(self, name='initialize', block=inits)
-                set_domain_module = Module(self, name='set_domain', block=Block(self, SetDomain(self)), interface=True)
-            else:
-                initialize_module = Module(self, name='initialize', block=Block.merge_blocks(inits, self.create_domain))
-                set_domain_module = None
-                
-            setup_sim = Block.from_list(self, [
-                    self.setup_particles,
-                    UpdateDomain(self),
-                    self.setup_functions,
-                    BuildCellListsStencil(self, self.cell_lists),
-                ])
-
-            setup_sim_module = Module(self, name='setup_sim', block=setup_sim)
-            communicate_module = Module(self, name='communicate', block=Timestep(self, 0, comm_routine).block)
-            reset_volatiles_module = Module(self, name='reset_volatiles', block=Block(self, ResetVolatileProperties(self)))
-            
-            modules_list = [
-                update_cells_module, 
-                initialize_module,
-                setup_sim_module, 
-                reverse_comm_module, 
-                communicate_module, 
-                reset_volatiles_module
-            ]
-
-            if self.generate_set_domain_module:
-                modules_list += [set_domain_module]
-
-            Transformations(modules_list, self._target).apply_all()
+        # Initialization and setup functions, together with time-step loop
+        # UpdateDomain is added after setup_particles because particles must be already present in the simulation
+        body = Block.from_list(self, [
+            self.create_domain,
+            self.setup_particles,
+            UpdateDomain(self),        
+            self.setup_functions,
+            BuildCellListsStencil(self, self.cell_lists),
+            timestep.as_block()
+        ])
 
-            # user defined modules are transformed seperately as indvidual modules 
-            # i.e. they are transformed once again if already transformed in setup_sim or do_timestep
-            udf_internal = self.setup_functions + self.pre_step_functions + self.functions
-            udf_internal = [m[0] if isinstance(m, tuple) else m for m in udf_internal]
-            user_defined_modules = [Module(self, name=m.name, block=Block(self, m), user_defined=True) for m in udf_internal]
-            for i, m in enumerate(user_defined_modules):
-                m._id = udf_internal[i]._id
+        program = Module(self, name='main', block=Block.merge_blocks(inits, body))
 
-            Transformations(user_defined_modules, self._target).apply_all()
+        # Apply transformations
+        transformations = Transformations(program, self._target)
+        transformations.apply_all()
 
-            # Generate library
-            self.code_gen.generate_library(update_cells_module, 
-                                           user_defined_modules, 
-                                           initialize_module, 
-                                           set_domain_module,
-                                           setup_sim_module, 
-                                           reverse_comm_module, 
-                                           communicate_module, 
-                                           reset_volatiles_module)
+        # Generate whole program
+        self.code_gen.generate_program(program)
 
+        # Generate getters for the runtime functions
         self.code_gen.generate_interfaces()