diff --git a/examples/dem_sd.py b/examples/dem_sd.py
index 94bc93f332733e7f03a0a7c440f81d4707984ae5..da81aad4002cfb84d78f0cabed2ce40ab524198f 100644
--- a/examples/dem_sd.py
+++ b/examples/dem_sd.py
@@ -132,8 +132,7 @@ psim.add_feature('type', ntypes)
 psim.add_feature_property('type', 'friction_static', pairs.real(), [frictionStatic for i in range(ntypes * ntypes)])
 psim.add_feature_property('type', 'friction_dynamic', pairs.real(), [frictionDynamic for i in range(ntypes * ntypes)])
 
-psim.set_domain([0.0, 0.0, 0.0, domainSize_SI[0], domainSize_SI[1], domainSize_SI[2]])
-# psim.set_domain_partitioner(pairs.block_forest(), initDomainFromWalberla=True)
+# psim.set_domain([0.0, 0.0, 0.0, domainSize_SI[0], domainSize_SI[1], domainSize_SI[2]])
 psim.set_domain_partitioner(pairs.block_forest())
 # psim.set_domain_partitioner(pairs.regular_domain_partitioner())
 psim.pbc([False, False, False])
diff --git a/examples/main.cpp b/examples/main.cpp
index 0bae49453d9393c74ee4d6dbd3819c9054d248e5..42a9575e000df84e33af7aa51adb683f3b94c5d6 100644
--- a/examples/main.cpp
+++ b/examples/main.cpp
@@ -25,7 +25,7 @@ int main(int argc, char **argv) {
     pairs_sim->initialize();
 
     // either create new domain or use an existing one ----------------------------------------
-    pairs_sim->create_domain(argc, argv);
+    pairs_sim->set_domain(argc, argv, 0, 0, 0, 0.1, 0.1, 0.1);
 
     // pairs_sim->use_domain(forest);
 
diff --git a/runtime/domain/domain_partitioning.hpp b/runtime/domain/domain_partitioning.hpp
index 15c56f3a509e784fa88bf2c9bc45abe926fd5916..48d96462509061c0e231cef97341e09884493f7c 100644
--- a/runtime/domain/domain_partitioning.hpp
+++ b/runtime/domain/domain_partitioning.hpp
@@ -37,6 +37,8 @@ public:
         delete[] grid_max;
     }
 
+    double getMin(int dim) const { return grid_min[dim]; }
+    double getMax(int dim) const { return grid_max[dim]; }
     virtual void initialize(int *argc, char ***argv) = 0;
     virtual void update() = 0;
     virtual int getWorldSize() const = 0;
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 6795e2d239ea5918a84701ad53eed47f2b99128d..766dcba2121371f82768829fb58fcad289053fba 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -172,20 +172,27 @@ class CGen:
         self.print("using namespace pairs;")
         self.print("")
 
-    def generate_module_headers(self):
-        self.print("")
+    def generate_module_header(self, module, definition=True):
+        module_params = []
+        if not module.user_defined and not module.interface:
+            module_params += ["PairsRuntime *pairs_runtime", "struct PairsObjects *pobj"]
+
+        if module.name=="init_domain" or module.name=="set_domain":
+            module_params += ["int argc", "char **argv"]
+
+        module_params += [f"{Types.c_keyword(self.sim, param.type())} {param.name()}" for param in module.parameters()]
 
-        self.print("// Module headers")
+        print_params = ", ".join(module_params)
+        ending = "{" if definition else ";"
+        self.print(f"void {module.name}({print_params}){ending}")
 
+    def generate_module_decls(self):
+        self.print("")
         self.print("namespace pairs::internal {")
         self.print.add_indent(4)
 
         for module in self.sim.modules():
-            if module.name != "main" and not module.user_defined:
-                module_params = ", ".join(f"{Types.c_keyword(self.sim, param.type())} {param.name()}"
-                                                for param in module.parameters())
-                module_params = ", " + module_params if module_params else ""
-                self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj{module_params});")
+            self.generate_module_header(module, definition=False)
         
         self.print.add_indent(-4)
         self.print("}")
@@ -370,7 +377,7 @@ class CGen:
         self.print.start()
         self.generate_preamble()
         self.generate_pairs_object_structure()
-        self.generate_module_headers()
+        self.generate_module_decls()
 
         self.print("namespace pairs::internal {")
         self.print.add_indent(4)
@@ -391,7 +398,16 @@ class CGen:
 
         self.print.end()
 
-    def generate_library(self, update_cells_module, user_defined_modules, initialize_module, create_domain_module, setup_sim_module, reverse_comm_module, communicate_module, reset_volatiles_module):
+    def generate_library(self, 
+                         update_cells_module, 
+                         user_defined_modules, 
+                         initialize_module,
+                         set_domain_module, 
+                         setup_sim_module, 
+                         reverse_comm_module, 
+                         communicate_module, 
+                         reset_volatiles_module):
+        
         self.generate_interfaces()
         # Generate CUDA/CPP file with modules
         ext = ".cu" if self.target.is_gpu() else ".cpp"
@@ -425,7 +441,7 @@ class CGen:
             self.generate_kernel(kernel)
 
         for module in self.sim.modules():
-            if module.name not in ['update_cells', 'initialize', 'create_domain', 'setup_sim', 'reverse_comm', 'communicate', 'reset_volatiles']:
+            if module.name not in ['update_cells', 'initialize', 'set_domain', 'setup_sim', 'reverse_comm', 'communicate', 'reset_volatiles']:
                 if not module.user_defined:
                     self.generate_module(module)
 
@@ -441,7 +457,7 @@ class CGen:
 
         self.generate_preamble()
         self.generate_pairs_object_structure()
-        self.generate_module_headers()
+        self.generate_module_decls()
 
         ndims = module.sim.ndims()
         nprops = module.sim.properties.nprops()
@@ -463,17 +479,19 @@ class CGen:
             self.generate_module(module)
             self.print("")
         
-        self.print("void initialize() {")
+        if set_domain_module:
+            self.print("void initialize() {")
+        else:
+            self.print("void initialize(int argc, char **argv) {")
         self.print(f"    pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});")
         self.print(f"    pobj = new PairsObjects();")
         self.generate_statement(initialize_module.block)
         self.print("}")
         self.print("")
 
-        self.print("void create_domain(int argc, char **argv) {")
-        self.generate_statement(create_domain_module.block)
-        self.print("}")
-        self.print("")
+        if set_domain_module:
+            self.generate_module(set_domain_module)
+            self.print("")
 
         self.print("template<typename Domain_T>")
         self.print("void use_domain(const std::shared_ptr<Domain_T> &domain_ptr) {")
@@ -631,22 +649,13 @@ class CGen:
         self.generate_full_object_names = False
 
     def generate_module(self, module):
-        module_params = ", ".join(f"{Types.c_keyword(self.sim, param.type())} {param.name()}"
-                                            for param in module.parameters())
-        if not module.user_defined:
-            module_params = ", " + module_params if module_params else ""
-            self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj{module_params}) {{")
-        else:
-            
-            self.print(f"void {module.name}({module_params}) {{")
-
-
+        self.generate_module_header(module, definition=True)
         self.print.add_indent(4)
 
         if self.debug:
             self.print(f"PAIRS_DEBUG(\"\\n{module.name}\\n\");")
 
-        if not module.user_defined:
+        if not module.user_defined and not module.interface:
             self.generate_module_declerations(module)
 
         self.print.add_indent(-4)
@@ -717,14 +726,8 @@ class CGen:
         self.print.add_indent(4)
         self.kernel_context = True
 
-        # if has_resizes:
-            # self.print(f"printf(\"{kernel.name} @@@@@@@@ before kernel: resizes[0] = %d\\n\", resizes[0]);")
-
         self.generate_statement(kernel.block)
 
-        # if has_resizes:
-            # self.print(f"printf(\"{kernel.name} @@@@@@@@ after kernel: resizes[0] = %d\\n\", resizes[0]);")
-
         self.kernel_context = False
         self.print.add_indent(-4)
         self.print("}")
@@ -1070,9 +1073,15 @@ class CGen:
             self.print("}")
 
         if isinstance(ast_node, ModuleCall):
-            module_params = ", ".join(f"{param.name()}" for param in ast_node.module.parameters())
-            module_params = ", " + module_params if module_params else ""
-            self.print(f"pairs::internal::{ast_node.module.name}(pairs_runtime, pobj{module_params});")
+            module_params = ["pairs_runtime", "pobj"]
+
+            if ast_node.module.name=="init_domain":
+                module_params += ["argc", "argv"]
+
+            module_params += [f"{param.name()}" for param in ast_node.module.parameters()]
+
+            print_params = ", ".join(module_params)
+            self.print(f"pairs::internal::{ast_node.module.name}({print_params});")
 
         if isinstance(ast_node, Print):
             args = ast_node.args
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
index 52ee117c4eeaa193f0a2fb382fb8e64e3db1136e..04b4f850d57ed8c0e1ad4b28a4a9f30cce8384da 100644
--- a/src/pairs/ir/module.py
+++ b/src/pairs/ir/module.py
@@ -10,7 +10,7 @@ from pairs.ir.parameters import Parameter
 class Module(ASTNode):
     last_module = 0
 
-    def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False, user_defined=False):
+    def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False, user_defined=False, interface=False):
         super().__init__(sim)
         self._id = Module.last_module
         self._name = name if name is not None else "module" + str(Module.last_module)
@@ -26,6 +26,7 @@ class Module(ASTNode):
         self._check_properties_resize = check_properties_resize
         self._run_on_device = run_on_device
         self._user_defined = user_defined
+        self._interface = interface
         self._profile = False
         sim.add_module(self)
         Module.last_module += 1
@@ -53,6 +54,10 @@ class Module(ASTNode):
     def user_defined(self):
         return self._user_defined
 
+    @property
+    def interface(self):
+        return self._interface
+
     def profile(self):
         self._profile = True
         self.sim.enable_profiler()
diff --git a/src/pairs/sim/domain.py b/src/pairs/sim/domain.py
index 2e0d9877d28ca5f2dfd763ef21b20307ab5d752e..559384024a6999221814f086714dde8bca114374 100644
--- a/src/pairs/sim/domain.py
+++ b/src/pairs/sim/domain.py
@@ -1,4 +1,7 @@
-from pairs.ir.block import pairs_inline
+from pairs.ir.block import pairs_inline, pairs_host_block
+from pairs.ir.parameters import Parameter
+from pairs.ir.types import Types
+from pairs.ir.assign import Assign
 from pairs.sim.lowerable import Lowerable
 
 
@@ -10,6 +13,23 @@ class InitializeDomain(Lowerable):
     def lower(self):
         self.sim.domain_partitioning().initialize()
 
+class SetDomain(Lowerable):
+    def __init__(self, sim):
+        super().__init__(sim)
+
+    @pairs_host_block
+    def lower(self):
+        self.sim.module_name('init_domain')
+
+        for d in range(self.sim.ndims()):
+            dmin = Parameter(self.sim, f'd{d}_min', Types.Real)
+            Assign(self.sim, self.sim.grid.min(d), dmin)
+
+        for d in range(self.sim.ndims()):
+            dmax = Parameter(self.sim, f'd{d}_max', Types.Real)
+            Assign(self.sim, self.sim.grid.max(d), dmax)
+
+        self.sim.domain_partitioning().initialize()
 
 class UpdateDomain(Lowerable):
     def __init__(self, sim):
diff --git a/src/pairs/sim/domain_partitioning.py b/src/pairs/sim/domain_partitioning.py
index e3d95c06421ea9fab2ac09a3f111b564fe907f17..de3839618df03d0c09e1a3cd2d2d36b3b5ee2d77 100644
--- a/src/pairs/sim/domain_partitioning.py
+++ b/src/pairs/sim/domain_partitioning.py
@@ -1,12 +1,13 @@
 from pairs.ir.assign import Assign
 from pairs.ir.branches import Filter
 from pairs.ir.loops import For
-from pairs.ir.functions import Call_Int, Call_Void
+from pairs.ir.functions import Call_Int, Call_Void, Call
 from pairs.ir.scalars import ScalarOp
 from pairs.ir.select import Select
 from pairs.ir.types import Types
 from pairs.sim.flags import Flags
 from pairs.ir.lit import Lit
+from pairs.sim.grid import MutableGrid
 
 
 class DimensionRanges:
@@ -44,14 +45,19 @@ class DimensionRanges:
     def initialize(self):
         grid_array = [(self.sim.grid.min(d), self.sim.grid.max(d)) for d in range(self.sim.ndims())]
         Call_Void(self.sim, "pairs_runtime->initDomain", [param for delim in grid_array for param in delim])
-        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['neighbor_ranks', self.neighbor_ranks, self.sim.ndims() * 2])
-        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['pbc', self.pbc, self.sim.ndims() * 2])
-        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['subdom', self.subdom, self.sim.ndims() * 2])
 
     def update(self):
         Call_Void(self.sim, "pairs_runtime->updateDomain", [])
         Assign(self.sim, self.rank, Call_Int(self.sim, "pairs_runtime->getDomainPartitioner()->getRank", []))
 
+        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['neighbor_ranks', self.neighbor_ranks, self.sim.ndims() * 2])
+        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['pbc', self.pbc, self.sim.ndims() * 2])
+        Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['subdom', self.subdom, self.sim.ndims() * 2])
+
+        if isinstance(self.sim.grid, MutableGrid):
+            for d in range(self.sim.dims):
+                Assign(self.sim, self.sim.grid.min(d), Call(self.sim, "pairs_runtime->getDomainPartitioner()->getMin", [d], Types.Real))
+                Assign(self.sim, self.sim.grid.max(d), Call(self.sim, "pairs_runtime->getDomainPartitioner()->getMax", [d], Types.Real))
 
     def ghost_particles(self, step, position, offset=0.0):
         # Particles with one of the following flags are ignored
@@ -155,6 +161,11 @@ class BlockForest:
             Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['aabb_offsets', self.aabb_offsets, self.nranks])
             Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['aabbs', self.aabbs, self.ntotal_aabbs * 6])
             Call_Void(self.sim, "pairs_runtime->copyRuntimeArray", ['subdom', self.subdom, self.sim.ndims() * 2])
+        
+        if isinstance(self.sim.grid, MutableGrid):
+            for d in range(self.sim.dims):
+                Assign(self.sim, self.sim.grid.min(d), Call(self.sim, "pairs_runtime->getDomainPartitioner()->getMin", [d], Types.Real))
+                Assign(self.sim, self.sim.grid.max(d), Call(self.sim, "pairs_runtime->getDomainPartitioner()->getMax", [d], Types.Real))
 
     def ghost_particles(self, step, position, offset=0.0):
         ''' TODO :  If we have pbc, a sinlge particle can be a ghost particle multiple times (at different locations) for the same neighbor block,
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 20b439941bfb23719bd136ee6d6db460f0053bfb..208f1ad3231417de9e1719d8c880e2f70c20a62a 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -17,11 +17,11 @@ from pairs.sim.comm import Comm
 from pairs.sim.contact_history import ContactHistory, BuildContactHistory, ClearUnusedContactHistory, ResetContactHistoryUsageStatus
 from pairs.sim.copper_fcc_lattice import CopperFCCLattice
 from pairs.sim.dem_sc_grid import DEMSCGrid
-from pairs.sim.domain import InitializeDomain, UpdateDomain
+from pairs.sim.domain import InitializeDomain, UpdateDomain, SetDomain
 from pairs.sim.domain_partitioners import DomainPartitioners
 from pairs.sim.domain_partitioning import BlockForest, DimensionRanges
 from pairs.sim.features import AllocateFeatureProperties
-from pairs.sim.grid import Grid2D, Grid3D
+from pairs.sim.grid import Grid2D, Grid3D, MutableGrid
 from pairs.sim.instrumentation import RegisterMarkers, RegisterTimers
 from pairs.sim.lattice import ParticleLattice
 from pairs.sim.neighbor_lists import NeighborLists, BuildNeighborLists
@@ -92,6 +92,8 @@ class Simulation:
 
         # Different segments of particle code/functions
         self.create_domain = Block(self, [])
+        self.generate_set_domain_module = True
+
         self.setup_particles = Block(self, [])
         self.module_list = []
         self.kernel_list = []
@@ -265,6 +267,11 @@ class Simulation:
         return self.vars.find(var_name)
 
     def set_domain(self, grid):
+        """Set domain bounds. 
+        If the domain is set through this function, the 'set_domain' module won't be generated in the modular version.
+        Use this function only if you do not need to set domain at runtime.
+        This function is required only for whole-program generation."""
+        self.generate_set_domain_module = False
         self.grid = Grid3D(self, grid[0], grid[1], grid[2], grid[3], grid[4], grid[5])
         self.create_domain.add_statement(InitializeDomain(self))
 
@@ -532,6 +539,8 @@ class Simulation:
 
         # Combine everything into a whole program
         if self._generate_whole_program:
+            assert self.grid, "No domain is created. Set domain bounds with 'set_domain'."
+
             # Initialization and setup functions, together with time-step loop
             # UpdateDomain is added after setup_particles because particles must be already present in the simulation
             body = Block.from_list(self, [
@@ -557,9 +566,15 @@ class Simulation:
             update_cells = [m[0] if isinstance(m, tuple) else m for m in update_cells]
             update_cells_module = Module(self, name='update_cells', block=Block.from_list(self, update_cells))
 
-            initialize_module = Module(self, name='initialize', block=inits)
-            create_domain_module = Module(self, name='create_domain', block=self.create_domain)
-
+            # Either generate a set_domain module, or create domain during initialization 
+            if self.generate_set_domain_module:
+                self.grid = MutableGrid(self, self.dims)
+                initialize_module = Module(self, name='initialize', block=inits)
+                set_domain_module = Module(self, name='set_domain', block=Block(self, SetDomain(self)), interface=True)
+            else:
+                initialize_module = Module(self, name='initialize', block=Block.merge_blocks(inits, self.create_domain))
+                set_domain_module = None
+                
             setup_sim = Block.from_list(self, [
                     self.setup_particles,
                     UpdateDomain(self),
@@ -570,17 +585,19 @@ class Simulation:
             setup_sim_module = Module(self, name='setup_sim', block=setup_sim)
             communicate_module = Module(self, name='communicate', block=Timestep(self, 0, comm_routine).as_block())
             reset_volatiles_module = Module(self, name='reset_volatiles', block=Block(self, ResetVolatileProperties(self)))
-
+            
             modules_list = [
                 update_cells_module, 
-                initialize_module, 
-                create_domain_module, 
+                initialize_module,
                 setup_sim_module, 
                 reverse_comm_module, 
                 communicate_module, 
                 reset_volatiles_module
             ]
 
+            if self.generate_set_domain_module:
+                modules_list += [set_domain_module]
+
             Transformations(modules_list, self._target).apply_all()
 
             # user defined modules are transformed seperately as indvidual modules 
@@ -594,7 +611,7 @@ class Simulation:
             self.code_gen.generate_library(update_cells_module, 
                                            user_defined_modules, 
                                            initialize_module, 
-                                           create_domain_module, 
+                                           set_domain_module,
                                            setup_sim_module, 
                                            reverse_comm_module, 
                                            communicate_module,