diff --git a/AUTHORS.txt b/AUTHORS.txt
index d591db3e7e1f226f7cccfdf29478b5cf2f42124d..eb0653f6fe72c59e1f029402b3b4b26bcd9170f0 100644
--- a/AUTHORS.txt
+++ b/AUTHORS.txt
@@ -11,3 +11,4 @@ Contributors:
   - Rudolf Weeber <weeber@icp.uni-stuttgart.de>
   - Christian Godenschwager <christian.godenschwager@fau.de>
   - Jan Hönig <jan.hoenig@fau.de>
+  - Philipp Suffa <philipp.suffa@fau.de>
diff --git a/src/lbmpy/creationfunctions.py b/src/lbmpy/creationfunctions.py
index 9bcf0f6eecf9b1924e1c97eebadf198f61125d89..5e2537193f2f1b2168b79d32c2eef30b1dbb9298 100644
--- a/src/lbmpy/creationfunctions.py
+++ b/src/lbmpy/creationfunctions.py
@@ -67,7 +67,8 @@ from lbmpy.enums import Stencil, Method, ForceModel, CollisionSpace, SubgridScal
 import lbmpy.forcemodels as forcemodels
 from lbmpy.fieldaccess import CollideOnlyInplaceAccessor, PdfFieldAccessor, PeriodicTwoFieldsAccessor
 from lbmpy.fluctuatinglb import add_fluctuations_to_collision_rule
-from lbmpy.partially_saturated_cells import add_psm_to_collision_rule, PSMConfig
+from lbmpy.partially_saturated_cells import (replace_by_psm_collision_rule, PSMConfig,
+                                             add_psm_solid_collision_to_collision_rule)
 from lbmpy.non_newtonian_models import add_cassons_model, CassonsParameters
 from lbmpy.methods import (create_mrt_orthogonal, create_mrt_raw, create_central_moment,
                            create_srt, create_trt, create_trt_kbc)
@@ -468,7 +469,7 @@ class LBMConfig:
         }
 
         if self.psm_config is not None and self.psm_config.fraction_field is not None:
-            self.force = [(1.0 - self.psm_config.fraction_field.center) * f for f in self.force]
+            self.force = [(1.0 - self.psm_config.fraction_field_symbol) * f for f in self.force]
 
         if isinstance(self.force_model, str):
             new_force_model = ForceModel[self.force_model.upper()]
@@ -684,11 +685,6 @@ def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=N
     else:
         collision_rule = lb_method.get_collision_rule(pre_simplification=pre_simplification)
 
-    if lbm_config.psm_config is not None:
-        if lbm_config.psm_config.fraction_field is None or lbm_config.psm_config.object_velocity_field is None:
-            raise ValueError("Specify a fraction and object velocity field in the PSM Config")
-        collision_rule = add_psm_to_collision_rule(collision_rule, lbm_config.psm_config)
-
     if lbm_config.galilean_correction:
         from lbmpy.methods.cumulantbased import add_galilean_correction
         collision_rule = add_galilean_correction(collision_rule)
@@ -706,6 +702,11 @@ def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=N
                                                      bulk_relaxation_rate=lbm_config.relaxation_rates[1],
                                                      limiter=cumulant_limiter)
 
+    if lbm_config.psm_config is not None:
+        if lbm_config.psm_config.fraction_field is None or lbm_config.psm_config.object_velocity_field is None:
+            raise ValueError("Specify a fraction and object velocity field in the PSM Config")
+        collision_rule = replace_by_psm_collision_rule(collision_rule, lbm_config.psm_config)
+
     if lbm_config.entropic:
         if lbm_config.subgrid_scale_model or lbm_config.cassons:
             raise ValueError("Choose either entropic, subgrid-scale or cassons")
@@ -783,7 +784,7 @@ def create_lb_method(lbm_config=None, **params):
     if lbm_config.psm_config is None:
         fraction_field = None
     else:
-        fraction_field = lbm_config.psm_config.fraction_field
+        fraction_field = lbm_config.psm_config.fraction_field_symbol
 
     common_params = {
         'compressible': lbm_config.compressible,
@@ -869,49 +870,36 @@ def create_lb_method(lbm_config=None, **params):
 
 
 def create_psm_update_rule(lbm_config, lbm_optimisation):
-    node_collection = []
 
-    # Use regular lb update rule for no overlapping particles
-    config_without_psm = copy.deepcopy(lbm_config)
-    config_without_psm.psm_config = None
-    # TODO: the force is still multiplied by (1.0 - self.psm_config.fraction_field.center)
-    #  (should not harm if memory bound since self.psm_config.fraction_field.center should always be 0.0)
+    if lbm_config.psm_config is None:
+        raise ValueError("Specify a PSM Config in the LBM Config, when creating a psm update rule")
+
+    config_without_particles = copy.deepcopy(lbm_config)
+    config_without_particles.psm_config.max_particles_per_cell = 0
+
     lb_update_rule = create_lb_update_rule(
-        lbm_config=config_without_psm, lbm_optimisation=lbm_optimisation
-    )
-    node_collection.append(
-        Conditional(
-            lbm_config.psm_config.fraction_field.center(0) <= 0.0,
-            Block(lb_update_rule.all_assignments),
-        )
-    )
+        lbm_config=config_without_particles, lbm_optimisation=lbm_optimisation)
+
+    node_collection = lb_update_rule.all_assignments
 
-    # Only one particle, i.e., no individual_fraction_field is provided
     if lbm_config.psm_config.individual_fraction_field is None:
-        assert lbm_config.psm_config.MaxParticlesPerCell == 1
+        assert lbm_config.psm_config.max_particles_per_cell == 1
+        fraction_field = lbm_config.psm_config.fraction_field
+    else:
+        fraction_field = lbm_config.psm_config.individual_fraction_field
+
+    for p in range(lbm_config.psm_config.max_particles_per_cell):
+
+        psm_solid_collision = add_psm_solid_collision_to_collision_rule(lb_update_rule, lbm_config, p)
         psm_update_rule = create_lb_update_rule(
-            lbm_config=lbm_config, lbm_optimisation=lbm_optimisation
-        )
+            collision_rule=psm_solid_collision, lbm_config=lbm_config, lbm_optimisation=lbm_optimisation)
+
         node_collection.append(
             Conditional(
-                lbm_config.psm_config.fraction_field.center(0) > 0.0,
+                fraction_field.center(p) > 0.0,
                 Block(psm_update_rule.all_assignments),
             )
         )
-    else:
-        for p in range(lbm_config.psm_config.MaxParticlesPerCell):
-            # Add psm update rule for p overlapping particles
-            config_with_p_particles = copy.deepcopy(lbm_config)
-            config_with_p_particles.psm_config.MaxParticlesPerCell = p + 1
-            psm_update_rule = create_lb_update_rule(
-                lbm_config=config_with_p_particles, lbm_optimisation=lbm_optimisation
-            )
-            node_collection.append(
-                Conditional(
-                    lbm_config.psm_config.individual_fraction_field.center(p) > 0.0,
-                    Block(psm_update_rule.all_assignments),
-                )
-            )
 
     return NodeCollection(node_collection)
 
diff --git a/src/lbmpy/custom_code_nodes.py b/src/lbmpy/custom_code_nodes.py
index 6540f382c1f44a84cc6cceb94c653c5166709237..2604a2b944265529f6cd1862bc92d7b11c9154c3 100644
--- a/src/lbmpy/custom_code_nodes.py
+++ b/src/lbmpy/custom_code_nodes.py
@@ -68,7 +68,7 @@ class LbmWeightInfo(CustomCodeNode):
         weights = [f"(({self.weights_symbol.dtype.c_name})({str(w.evalf(17))}))" for w in lb_method.weights]
         weights = ", ".join(weights)
         w_sym = self.weights_symbol
-        code = f"const {self.weights_symbol.dtype.c_name} {w_sym.name} [] = {{{ weights }}};\n"
+        code = f"const {self.weights_symbol.dtype.c_name} {w_sym.name} [] = {{{weights}}};\n"
         super(LbmWeightInfo, self).__init__(code, symbols_read=set(), symbols_defined={w_sym})
 
     def weight_of_direction(self, dir_idx, lb_method=None):
diff --git a/src/lbmpy/methods/creationfunctions.py b/src/lbmpy/methods/creationfunctions.py
index 01cd85c61fdb0d1b71efaaf280d57d0df420c02b..eece583fb2e19c569882abfa4b3d6938667d6a65 100644
--- a/src/lbmpy/methods/creationfunctions.py
+++ b/src/lbmpy/methods/creationfunctions.py
@@ -372,7 +372,7 @@ def create_central_moment(stencil, relaxation_rates, nested_moments=None,
 
     rr_dict = _get_relaxation_info_dict(relaxation_rates, nested_moments, stencil.D, conserved_moments)
     if fraction_field is not None:
-        relaxation_rates_modifier = (1.0 - fraction_field.center)
+        relaxation_rates_modifier = (1.0 - fraction_field)
         rr_dict = _get_relaxation_info_dict(relaxation_rates, nested_moments, stencil.D,
                                             relaxation_rates_modifier=relaxation_rates_modifier)
 
@@ -548,7 +548,7 @@ def create_cumulant(stencil, relaxation_rates, cumulant_groups, conserved_moment
     cumulant_to_rr_dict = _get_relaxation_info_dict(relaxation_rates, cumulant_groups, stencil.D, conserved_moments)
 
     if fraction_field is not None:
-        relaxation_rates_modifier = (1.0 - fraction_field.center)
+        relaxation_rates_modifier = (1.0 - fraction_field)
         cumulant_to_rr_dict = _get_relaxation_info_dict(relaxation_rates, cumulant_groups, stencil.D,
                                                         relaxation_rates_modifier=relaxation_rates_modifier)
 
diff --git a/src/lbmpy/methods/momentbased/momentbasedmethod.py b/src/lbmpy/methods/momentbased/momentbasedmethod.py
index 740834072a0210772d407a7e441b9e1ef46d0f7a..1168889e89d27e8cac54ea3d4fb624f5c82e51c9 100644
--- a/src/lbmpy/methods/momentbased/momentbasedmethod.py
+++ b/src/lbmpy/methods/momentbased/momentbasedmethod.py
@@ -177,7 +177,7 @@ class MomentBasedLbMethod(AbstractLbMethod):
                            pre_simplification: bool = True) -> LbmCollisionRule:
 
         if self.fraction_field is not None:
-            relaxation_rates_modifier = (1.0 - self.fraction_field.center)
+            relaxation_rates_modifier = (1.0 - self.fraction_field)
             rr_sub_expressions, d = self._generate_symbolic_relaxation_matrix(
                 relaxation_rates_modifier=relaxation_rates_modifier)
         else:
diff --git a/src/lbmpy/partially_saturated_cells.py b/src/lbmpy/partially_saturated_cells.py
index 0798ac69501b46ffc22251ff13e7f5ecc036d3b9..35196e498fec412a458217e907fb68d10ea91536 100644
--- a/src/lbmpy/partially_saturated_cells.py
+++ b/src/lbmpy/partially_saturated_cells.py
@@ -1,6 +1,7 @@
 import sympy as sp
 from dataclasses import dataclass
 
+from lbmpy.enums import Method
 from lbmpy.methods.abstractlbmethod import LbmCollisionRule
 from pystencils import Assignment, AssignmentCollection
 from pystencils.field import Field
@@ -13,103 +14,154 @@ class PSMConfig:
     Fraction field for PSM 
     """
 
+    fraction_field_symbol = sp.Symbol('B')
+    """
+    Fraction field symbol used for simplification 
+    """
+
     object_velocity_field: Field = None
     """
     Object velocity field for PSM 
     """
 
-    SC: int = 1
+    solid_collision: int = 1
     """
     Solid collision option for PSM
     """
 
-    MaxParticlesPerCell: int = 1
+    max_particles_per_cell: int = 1
     """
     Maximum number of particles overlapping with a cell 
     """
 
     individual_fraction_field: Field = None
     """
-    Fraction field for each overlapping particle in PSM 
+    Fraction field for each overlapping object / particle in PSM 
     """
 
-    particle_force_field: Field = None
+    object_force_field: Field = None
     """
-    Force field for each overlapping particle in PSM 
+    Force field for each overlapping object / particle in PSM 
     """
 
 
-def add_psm_to_collision_rule(collision_rule, psm_config):
+def get_psm_solid_collision_term(collision_rule, psm_config, particle_per_cell_counter):
     if psm_config.individual_fraction_field is None:
-        psm_config.individual_fraction_field = psm_config.fraction_field
+        fraction_field = psm_config.fraction_field
+    else:
+        fraction_field = psm_config.individual_fraction_field
 
     method = collision_rule.method
     pre_collision_pdf_symbols = method.pre_collision_pdf_symbols
     stencil = method.stencil
 
-    # Get equilibrium from object velocity for solid collision
-    forces_rhs = [0] * psm_config.MaxParticlesPerCell * stencil.D
     solid_collisions = [0] * stencil.Q
-    for p in range(psm_config.MaxParticlesPerCell):
-        equilibrium_fluid = method.get_equilibrium_terms()
-        equilibrium_solid = []
-        for eq in equilibrium_fluid:
-            eq_sol = eq
-            for i in range(stencil.D):
-                eq_sol = eq_sol.subs(sp.Symbol("u_" + str(i)),
-                                     psm_config.object_velocity_field.center(p * stencil.D + i), )
-            equilibrium_solid.append(eq_sol)
-
-        # Build solid collision
-        for i, (eqFluid, eqSolid, f, offset) in enumerate(
-                zip(equilibrium_fluid, equilibrium_solid, pre_collision_pdf_symbols, stencil)):
-            inverse_direction_index = stencil.stencil_entries.index(stencil.inverse_stencil_entries[i])
-            if psm_config.SC == 1:
-                solid_collision = psm_config.individual_fraction_field.center(p) * (
-                    (
-                        pre_collision_pdf_symbols[inverse_direction_index]
-                        - equilibrium_fluid[inverse_direction_index]
-                    )
-                    - (f - eqSolid)
-                )
-            elif psm_config.SC == 2:
-                # TODO get relaxation rate vector from method and use the right relaxation rate [i]
-                solid_collision = psm_config.individual_fraction_field.center(p) * (
-                    (eqSolid - f) + (1.0 - method.relaxation_rates[0]) * (f - eqFluid)
-                )
-            elif psm_config.SC == 3:
-                solid_collision = psm_config.individual_fraction_field.center(p) * (
-                    (
-                        pre_collision_pdf_symbols[inverse_direction_index]
-                        - equilibrium_solid[inverse_direction_index]
-                    )
-                    - (f - eqSolid)
-                )
-            else:
-                raise ValueError("Only SC=1, SC=2 and SC=3 are supported.")
-            solid_collisions[i] += solid_collision
-            for j in range(stencil.D):
-                forces_rhs[p * stencil.D + j] -= solid_collision * int(offset[j])
-
-    # Add solid collision to main assignments of collision rule
+    equilibrium_fluid = method.get_equilibrium_terms()
+    equilibrium_solid = []
+
+    # get equilibrium form object velocity
+    for eq in equilibrium_fluid:
+        eq_sol = eq
+        for i in range(stencil.D):
+            eq_sol = eq_sol.subs(sp.Symbol("u_" + str(i)),
+                                 psm_config.object_velocity_field.center(particle_per_cell_counter * stencil.D + i), )
+        equilibrium_solid.append(eq_sol)
+
+    # Build solid collision
+    for i, (eqFluid, eqSolid, f, offset) in enumerate(
+            zip(equilibrium_fluid, equilibrium_solid, pre_collision_pdf_symbols, stencil)):
+        inverse_direction_index = stencil.stencil_entries.index(stencil.inverse_stencil_entries[i])
+        if psm_config.solid_collision == 1:
+            solid_collision = (fraction_field.center(particle_per_cell_counter)
+                               * ((pre_collision_pdf_symbols[inverse_direction_index]
+                                   - equilibrium_fluid[inverse_direction_index]) - (f - eqSolid)))
+        elif psm_config.solid_collision == 2:
+            # TODO get relaxation rate vector from method and use the right relaxation rate [i]
+            solid_collision = (fraction_field.center(particle_per_cell_counter)
+                               * ((eqSolid - f) + (1.0 - method.relaxation_rates[0]) * (f - eqFluid)))
+        elif psm_config.solid_collision == 3:
+            solid_collision = (fraction_field.center(particle_per_cell_counter)
+                               * ((pre_collision_pdf_symbols[inverse_direction_index]
+                                   - equilibrium_solid[inverse_direction_index]) - (f - eqSolid)))
+        else:
+            raise ValueError("Only sc=1, sc=2 and sc=3 are supported.")
+
+        solid_collisions[i] += solid_collision
+
+    return solid_collisions
+
+
+def get_psm_force_from_solid_collision(solid_collisions, stencil, object_force_field, particle_per_cell_counter):
+    force_assignments = []
+    for d in range(stencil.D):
+        forces_rhs = 0
+        for sc, offset in zip(solid_collisions, stencil):
+            forces_rhs -= sc * int(offset[d])
+
+        force_assignments.append(Assignment(
+            object_force_field.center(particle_per_cell_counter * stencil.D + d), forces_rhs
+        ))
+    return AssignmentCollection(force_assignments)
+
+
+def replace_fraction_symbol_with_field(assignments, psm_config):
+    new_assignments = []
+    for ass in assignments:
+        rhs = ass.rhs.subs(psm_config.fraction_field_symbol, psm_config.fraction_field.center(0))
+        new_assignments.append(Assignment(ass.lhs, rhs))
+    return new_assignments
+
+
+def add_psm_solid_collision_to_collision_rule(collision_rule, lbm_config, particle_per_cell_counter):
+
+    method = collision_rule.method
+    solid_collisions = get_psm_solid_collision_term(collision_rule, lbm_config.psm_config, particle_per_cell_counter)
+    post_collision_pdf_symbols = method.post_collision_pdf_symbols
+
+    assignments = []
+    for sc, post in zip(solid_collisions, post_collision_pdf_symbols):
+        assignments.append(Assignment(post, post + sc))
+
+    if lbm_config.psm_config.object_force_field is not None:
+        assignments += get_psm_force_from_solid_collision(solid_collisions, method.stencil,
+                                                          lbm_config.psm_config.object_force_field,
+                                                          particle_per_cell_counter)
+
+    # exchanging rho with zeroth order moment symbol
+    if lbm_config.method in (Method.CENTRAL_MOMENT, Method.MONOMIAL_CUMULANT, Method.CUMULANT):
+        new_assignments = []
+        zeroth_moment_symbol = 'm_00' if lbm_config.stencil.D == 2 else 'm_000'
+        for ass in assignments:
+            new_assignments.append(ass.subs(sp.Symbol('rho'), sp.Symbol(zeroth_moment_symbol)))
+        assignments = new_assignments
+
+    collision_assignments = AssignmentCollection(assignments)
+    ac = LbmCollisionRule(method, collision_assignments, [],
+                          collision_rule.simplification_hints)
+    return ac
+
+
+def replace_by_psm_collision_rule(collision_rule, psm_config):
+
+    method = collision_rule.method
     collision_assignments = []
-    for main, sc in zip(collision_rule.main_assignments, solid_collisions):
-        collision_assignments.append(Assignment(main.lhs, main.rhs + sc))
-
-    # Add hydrodynamic force calculations to collision assignments if two-way coupling is used
-    # (i.e., force field is not None)
-    if psm_config.particle_force_field is not None:
-        for p in range(psm_config.MaxParticlesPerCell):
-            for i in range(stencil.D):
-                collision_assignments.append(
-                    Assignment(
-                        psm_config.particle_force_field.center(p * stencil.D + i),
-                        forces_rhs[p * stencil.D + i],
-                    )
-                )
+    solid_collisions = [0] * psm_config.max_particles_per_cell
+    for p in range(psm_config.max_particles_per_cell):
+        solid_collisions[p] = get_psm_solid_collision_term(collision_rule, psm_config, p)
+
+        if psm_config.object_force_field is not None:
+            collision_assignments += get_psm_force_from_solid_collision(solid_collisions[p], method.stencil,
+                                                                        psm_config.object_force_field, p)
+
+    for i, main in enumerate(collision_rule.main_assignments):
+        rhs = main.rhs
+        for p in range(psm_config.max_particles_per_cell):
+            rhs += solid_collisions[p][i]
+        collision_assignments.append(Assignment(main.lhs, rhs))
 
     collision_assignments = AssignmentCollection(collision_assignments)
-    ac = LbmCollisionRule(method, collision_assignments, collision_rule.subexpressions,
+    ac = LbmCollisionRule(method, replace_fraction_symbol_with_field(collision_assignments, psm_config),
+                          replace_fraction_symbol_with_field(collision_rule.subexpressions, psm_config),
                           collision_rule.simplification_hints)
     ac.topological_sort()
     return ac