diff --git a/examples/dem.py b/examples/dem.py
index 9aec3a121ac26a2fe5c441c3fe812b59cc37db9d..ca8ac06cbc14bfa4e5b89634adb3f8668761ab63 100644
--- a/examples/dem.py
+++ b/examples/dem.py
@@ -161,6 +161,7 @@ psim.add_contact_property('tangential_spring_displacement', pairs.vector(), [0.0
 psim.add_contact_property('impact_velocity_magnitude', pairs.real(), 0.0)
 
 psim.set_domain([0.0, 0.0, 0.0, domainSize_SI[0], domainSize_SI[1], domainSize_SI[2]])
+psim.set_domain_partitioner(pairs.regular_domain_partitioner_xy())
 psim.pbc([True, True, False])
 psim.read_particle_data(
     "data/spheres.input",
diff --git a/examples/lj.py b/examples/lj.py
index 8f2414dedec1858460c1737182f2925a087d667b..7d555f87040a85c53a0c387b70bc1caa65a9969a 100644
--- a/examples/lj.py
+++ b/examples/lj.py
@@ -49,12 +49,16 @@ psim.add_property('force', pairs.vector(), volatile=True)
 psim.add_feature('type', ntypes)
 psim.add_feature_property('type', 'epsilon', pairs.real(), [sigma for i in range(ntypes * ntypes)])
 psim.add_feature_property('type', 'sigma6', pairs.real(), [epsilon for i in range(ntypes * ntypes)])
+
 psim.copper_fcc_lattice(nx, ny, nz, rho, temp, ntypes)
+psim.set_domain_partitioner(pairs.regular_domain_partitioner())
+psim.compute_thermo(100)
+
 psim.reneighbor_every(20)
 #psim.compute_half()
 psim.build_neighbor_lists(cutoff_radius + skin)
-psim.compute_thermo(100)
 #psim.vtk_output(f"output/lj_{target}")
+
 psim.compute(initial_integrate, symbols={'dt': dt}, pre_step=True, skip_first=True)
 psim.compute(lj, cutoff_radius)
 psim.compute(final_integrate, symbols={'dt': dt}, skip_first=True)
diff --git a/runtime/domain/regular_6d_stencil.cpp b/runtime/domain/regular_6d_stencil.cpp
index 8ac99deb16b4ec8bca89ced48617d7521b5ac495..0c6f100fd8afd93f473acab26504c236ec81f8f3 100644
--- a/runtime/domain/regular_6d_stencil.cpp
+++ b/runtime/domain/regular_6d_stencil.cpp
@@ -24,12 +24,13 @@ void Regular6DStencil::setConfig() {
         }
     }
 
-    for(int i = 1; i < world_size; i++) {
-        if(world_size % i == 0) {
-            const int rem_yz = world_size / i;
-            for(int j = 1; j < rem_yz; j++) {
-                if(rem_yz % j == 0) {
-                    const int k = rem_yz / j;
+    const int imax = partition_flags[0] ? world_size : 1;
+    const int jmax = partition_flags[1] ? world_size : 1;
+    const int kmax = partition_flags[2] ? world_size : 1;
+    for(int i = 1; i <= imax; i++) {
+        for(int j = 1; j <= jmax; j++) {
+            for(int k = 1; k <= kmax; k++) {
+                if((i * j * k) == world_size) {
                     const real_t surf = (area[0] / i / j) + (area[1] / i / k) + (area[2] / j / k);
                     if(surf < best_surf) {
                         nranks[0] = i;
diff --git a/runtime/domain/regular_6d_stencil.hpp b/runtime/domain/regular_6d_stencil.hpp
index 8a43b2d4954bc1fae4eff570a2972a8484dcf461..330af65a6ccb140cef8d283eac3ab183d0503c45 100644
--- a/runtime/domain/regular_6d_stencil.hpp
+++ b/runtime/domain/regular_6d_stencil.hpp
@@ -10,6 +10,7 @@ namespace pairs {
 class Regular6DStencil : public DomainPartitioner {
 private:
     int world_size, rank;
+    int *partition_flags;
     int *nranks;
     int *prev;
     int *next;
@@ -19,7 +20,8 @@ private:
     real_t *subdom_max;
 
 public:
-    Regular6DStencil(real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax) :
+    Regular6DStencil(
+        real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax, const int part[]) :
         DomainPartitioner(xmin, xmax, ymin, ymax, zmin, zmax) {
 
         nranks = new int[ndims];
@@ -29,6 +31,11 @@ public:
         pbc_next = new int[ndims];
         subdom_min = new real_t[ndims];
         subdom_max = new real_t[ndims];
+        partition_flags = new int[ndims];
+
+        for(int d = 0; d < ndims; d++) {
+            partition_flags[d] = part[d];
+        }
     }
 
     ~Regular6DStencil() {
diff --git a/runtime/pairs.cpp b/runtime/pairs.cpp
index 24e40a0c97fcd3f0de1c530d5de954f0a7dd4366..d4291b8fbbe02a87e317b2d864c96bdc55f05829 100644
--- a/runtime/pairs.cpp
+++ b/runtime/pairs.cpp
@@ -15,8 +15,12 @@ void PairsSimulation::initDomain(
     int *argc, char ***argv,
     real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax) {
 
-    if(dom_part_type == DimRanges) {
-        dom_part = new Regular6DStencil(xmin, xmax, ymin, ymax, zmin, zmax);
+    if(dom_part_type == Regular) {
+        const int flags[] = {1, 1, 1};
+        dom_part = new Regular6DStencil(xmin, xmax, ymin, ymax, zmin, zmax, flags);
+    } else if(dom_part_type == RegularXY) {
+        const int flags[] = {1, 1, 0};
+        dom_part = new Regular6DStencil(xmin, xmax, ymin, ymax, zmin, zmax, flags);
     } else {
         PAIRS_EXCEPTION("Domain partitioning type not implemented!\n");
     }
diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp
index 0cfbefb78e2d3d5fc73513c18fd889a80c36a0bf..3a8a58f983dde43c90c1375362bcdf37065743c6 100644
--- a/runtime/pairs.hpp
+++ b/runtime/pairs.hpp
@@ -28,16 +28,16 @@ class PairsSimulation {
 private:
     Regular6DStencil *dom_part;
     //DomainPartitioner *dom_part;
+    DomainPartitioners dom_part_type;
     std::vector<Property> properties;
     std::vector<ContactProperty> contact_properties;
     std::vector<FeatureProperty> feature_properties;
     std::vector<Array> arrays;
     DeviceFlags *prop_flags, *contact_prop_flags, *array_flags;
-    DomainPartitioning dom_part_type;
     Timers<double> *timers;
 
 public:
-    PairsSimulation(int nprops_, int ncontactprops_, int narrays_, DomainPartitioning dom_part_type_) {
+    PairsSimulation(int nprops_, int ncontactprops_, int narrays_, DomainPartitioners dom_part_type_) {
         dom_part_type = dom_part_type_;
         prop_flags = new DeviceFlags(nprops_);
         contact_prop_flags = new DeviceFlags(ncontactprops_);
diff --git a/runtime/pairs_common.hpp b/runtime/pairs_common.hpp
index 525c0e3643ea4a86aad85d39f5c1ec7c15a31c54..61925a1746f52eb82eb6bb30ab5a0ca89208bfc8 100644
--- a/runtime/pairs_common.hpp
+++ b/runtime/pairs_common.hpp
@@ -44,9 +44,10 @@ enum Timers {
     Offset = 3
 };
 
-enum DomainPartitioning {
-    DimRanges = 0,
-    BoxList,
+enum DomainPartitioners {
+    Regular = 0,
+    RegularXY = 1,
+    BoxList = 2,
 };
 
 #ifdef DEBUG
diff --git a/src/pairs/__init__.py b/src/pairs/__init__.py
index 7ab7b1a7a1af753dd5b3dcfbe8b4a41a429b13ff..01126402422b66000359e9fffc7f964288d3df73 100644
--- a/src/pairs/__init__.py
+++ b/src/pairs/__init__.py
@@ -1,6 +1,7 @@
 from pairs.ir.types import Types
 from pairs.code_gen.cgen import CGen
 from pairs.code_gen.target import Target
+from pairs.sim.domain_partitioners import DomainPartitioners
 from pairs.sim.shapes import Shapes
 from pairs.sim.simulation import Simulation
 
@@ -43,3 +44,9 @@ def sphere():
 
 def halfspace():
     return Shapes.Halfspace
+
+def regular_domain_partitioner():
+    return DomainPartitioners.Regular
+
+def regular_domain_partitioner_xy():
+    return DomainPartitioners.RegularXY
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index e5c0300d3232f5edff01a14583f9e6bc5edb32a8..3b36c03204dba86febf95cfa4aab9a12fcb131df 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -29,6 +29,7 @@ from pairs.ir.types import Types
 from pairs.ir.utils import Print
 from pairs.ir.variables import Var, DeclareVariable, Deref
 from pairs.ir.vectors import Vector, VectorAccess, VectorOp, ZeroVector
+from pairs.sim.domain_partitioners import DomainPartitioners
 from pairs.sim.timestep import Timestep
 from pairs.code_gen.printer import Printer
 
@@ -115,8 +116,10 @@ class CGen:
             nprops = module.sim.properties.nprops()
             ncontactprops = module.sim.contact_properties.nprops()
             narrays = module.sim.arrays.narrays()
+            part = DomainPartitioners.c_keyword(module.sim.partitioner())
+
             self.print("int main(int argc, char **argv) {")
-            self.print(f"    PairsSimulation *pairs = new PairsSimulation({nprops}, {ncontactprops}, {narrays}, DimRanges);")
+            self.print(f"    PairsSimulation *pairs = new PairsSimulation({nprops}, {ncontactprops}, {narrays}, {part});")
 
             if module.sim._enable_profiler:
                 self.print("    LIKWID_MARKER_INIT;")
diff --git a/src/pairs/sim/domain_partitioners.py b/src/pairs/sim/domain_partitioners.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e00ad848f8496fbf8c24aa36ac3a06e813dcc07
--- /dev/null
+++ b/src/pairs/sim/domain_partitioners.py
@@ -0,0 +1,11 @@
+class DomainPartitioners:
+    Invalid = -1
+    Regular = 0
+    RegularXY = 1
+    BoxList = 2
+
+    def c_keyword(layout):
+        return "Regular"    if layout == DomainPartitioners.Regular else \
+               "RegularXY"  if layout == DomainPartitioners.RegularXY else \
+               "BoxList"    if layout == DomainPartitioners.BoxList else \
+               "Invalid"
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 37726abd2bf569cf6a4a3ec69a58df238ec19176..a1fefa45c053670b66665e4169212ad0f71c0b8a 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -17,6 +17,7 @@ from pairs.sim.comm import Comm
 from pairs.sim.contact_history import ContactHistory, BuildContactHistory, ClearUnusedContactHistory, ResetContactHistoryUsageStatus
 from pairs.sim.copper_fcc_lattice import CopperFCCLattice
 from pairs.sim.domain import InitializeDomain
+from pairs.sim.domain_partitioners import DomainPartitioners
 from pairs.sim.domain_partitioning import DimensionRanges
 from pairs.sim.features import AllocateFeatureProperties
 from pairs.sim.grid import Grid2D, Grid3D
@@ -76,8 +77,9 @@ class Simulation:
         self.reneighbor_frequency = 1
         self.vtk_file = None
         self.vtk_frequency = 0
+        self._dom_part = None
+        self._partitioner = None
         self._target = None
-        self._dom_part = DimensionRanges(self)
         self._pbc = [True for _ in range(dims)]
         self._use_contact_history = use_contact_history
         self._contact_history = ContactHistory(self) if use_contact_history else None
@@ -87,6 +89,18 @@ class Simulation:
         self._enable_profiler = False
         self._compute_thermo = 0
 
+    def set_domain_partitioner(self, partitioner):
+        self._partitioner = partitioner
+
+        if partitioner in (DomainPartitioners.Regular, DomainPartitioners.RegularXY):
+            self._dom_part = DimensionRanges(self)
+
+        else:
+            raise Exception("Invalid domain partitioner.")
+
+    def partitioner(self):
+        return self._partitioner
+
     def enable_profiler(self):
         self._enable_profiler = True