diff --git a/src/lbmpy/creationfunctions.py b/src/lbmpy/creationfunctions.py
index ebac5b2fa8ebd4ea8cbc36b52ef3254d94dae8d5..96e164eae381e5309c705938ff39c5304dc66406 100644
--- a/src/lbmpy/creationfunctions.py
+++ b/src/lbmpy/creationfunctions.py
@@ -702,6 +702,11 @@ def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=N
                                                      bulk_relaxation_rate=lbm_config.relaxation_rates[1],
                                                      limiter=cumulant_limiter)
 
+    if lbm_config.psm_config is not None:
+        if lbm_config.psm_config.fraction_field is None or lbm_config.psm_config.object_velocity_field is None:
+            raise ValueError("Specify a fraction and object velocity field in the PSM Config")
+        collision_rule = replace_by_psm_collision_rule(collision_rule, lbm_config.psm_config)
+
     if lbm_config.entropic:
         if lbm_config.subgrid_scale_model or lbm_config.cassons:
             raise ValueError("Choose either entropic, subgrid-scale or cassons")
@@ -756,11 +761,6 @@ def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=N
     if lbm_config.fluctuating:
         add_fluctuations_to_collision_rule(collision_rule, **lbm_config.fluctuating)
 
-    if lbm_config.psm_config is not None:
-        if lbm_config.psm_config.fraction_field is None or lbm_config.psm_config.object_velocity_field is None:
-            raise ValueError("Specify a fraction and object velocity field in the PSM Config")
-        collision_rule = replace_by_psm_collision_rule(collision_rule, lbm_config.psm_config)
-
     if lbm_optimisation.cse_pdfs:
         from lbmpy.methods.momentbased.momentbasedsimplifications import cse_in_opposing_directions
         collision_rule = cse_in_opposing_directions(collision_rule)
diff --git a/src/lbmpy/partially_saturated_cells.py b/src/lbmpy/partially_saturated_cells.py
index 1a66f4456203825faa5e198346d90c8eaf5fa41d..ea3821946e3bf16e6d60898109572dfb589a7099 100644
--- a/src/lbmpy/partially_saturated_cells.py
+++ b/src/lbmpy/partially_saturated_cells.py
@@ -103,8 +103,12 @@ def get_psm_force_from_solid_collision(solid_collisions, stencil, object_force_f
     return AssignmentCollection(force_assignments)
 
 
-def replace_fraction_symbol_with_field(psm_config, assignment):
-    return assignment.subs(psm_config.fraction_field_symbol, psm_config.fraction_field.center(0))
+def replace_fraction_symbol_with_field(assignments, psm_config):
+    new_assignments = []
+    for ass in assignments:
+        rhs = ass.rhs.subs(psm_config.fraction_field_symbol, psm_config.fraction_field.center(0))
+        new_assignments.append(Assignment(ass.lhs, rhs))
+    return new_assignments
 
 
 def add_psm_solid_collision_to_collision_rule(collision_rule, psm_config, particle_per_cell_counter):
@@ -140,14 +144,14 @@ def replace_by_psm_collision_rule(collision_rule, psm_config):
                                                                         psm_config.object_force_field, p)
 
     for i, main in enumerate(collision_rule.main_assignments):
-        rhs = replace_fraction_symbol_with_field(psm_config, main.rhs)
-
+        rhs =  main.rhs
         for p in range(psm_config.max_particles_per_cell):
             rhs += solid_collisions[p][i]
         collision_assignments.append(Assignment(main.lhs, rhs))
 
     collision_assignments = AssignmentCollection(collision_assignments)
-    ac = LbmCollisionRule(method, collision_assignments, collision_rule.subexpressions,
+    ac = LbmCollisionRule(method, replace_fraction_symbol_with_field(collision_assignments, psm_config),
+                          replace_fraction_symbol_with_field(collision_rule.subexpressions, psm_config),
                           collision_rule.simplification_hints)
     ac.topological_sort()
     return ac