From ad89e6993f724e30a47281bb5859fa081aa6fad6 Mon Sep 17 00:00:00 2001
From: Philipp Suffa <philipp.suffa@fau.de>
Date: Thu, 8 May 2025 12:03:17 +0200
Subject: [PATCH] Fixes error, where psm can not be build with cumulants [skip
 ci]

---
 src/lbmpy/creationfunctions.py         | 10 +++++-----
 src/lbmpy/partially_saturated_cells.py | 14 +++++++++-----
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/src/lbmpy/creationfunctions.py b/src/lbmpy/creationfunctions.py
index ebac5b2f..96e164ea 100644
--- a/src/lbmpy/creationfunctions.py
+++ b/src/lbmpy/creationfunctions.py
@@ -702,6 +702,11 @@ def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=N
                                                      bulk_relaxation_rate=lbm_config.relaxation_rates[1],
                                                      limiter=cumulant_limiter)
 
+    if lbm_config.psm_config is not None:
+        if lbm_config.psm_config.fraction_field is None or lbm_config.psm_config.object_velocity_field is None:
+            raise ValueError("Specify a fraction and object velocity field in the PSM Config")
+        collision_rule = replace_by_psm_collision_rule(collision_rule, lbm_config.psm_config)
+
     if lbm_config.entropic:
         if lbm_config.subgrid_scale_model or lbm_config.cassons:
             raise ValueError("Choose either entropic, subgrid-scale or cassons")
@@ -756,11 +761,6 @@ def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=N
     if lbm_config.fluctuating:
         add_fluctuations_to_collision_rule(collision_rule, **lbm_config.fluctuating)
 
-    if lbm_config.psm_config is not None:
-        if lbm_config.psm_config.fraction_field is None or lbm_config.psm_config.object_velocity_field is None:
-            raise ValueError("Specify a fraction and object velocity field in the PSM Config")
-        collision_rule = replace_by_psm_collision_rule(collision_rule, lbm_config.psm_config)
-
     if lbm_optimisation.cse_pdfs:
         from lbmpy.methods.momentbased.momentbasedsimplifications import cse_in_opposing_directions
         collision_rule = cse_in_opposing_directions(collision_rule)
diff --git a/src/lbmpy/partially_saturated_cells.py b/src/lbmpy/partially_saturated_cells.py
index 1a66f445..ea382194 100644
--- a/src/lbmpy/partially_saturated_cells.py
+++ b/src/lbmpy/partially_saturated_cells.py
@@ -103,8 +103,12 @@ def get_psm_force_from_solid_collision(solid_collisions, stencil, object_force_f
     return AssignmentCollection(force_assignments)
 
 
-def replace_fraction_symbol_with_field(psm_config, assignment):
-    return assignment.subs(psm_config.fraction_field_symbol, psm_config.fraction_field.center(0))
+def replace_fraction_symbol_with_field(assignments, psm_config):
+    new_assignments = []
+    for ass in assignments:
+        rhs = ass.rhs.subs(psm_config.fraction_field_symbol, psm_config.fraction_field.center(0))
+        new_assignments.append(Assignment(ass.lhs, rhs))
+    return new_assignments
 
 
 def add_psm_solid_collision_to_collision_rule(collision_rule, psm_config, particle_per_cell_counter):
@@ -140,14 +144,14 @@ def replace_by_psm_collision_rule(collision_rule, psm_config):
                                                                         psm_config.object_force_field, p)
 
     for i, main in enumerate(collision_rule.main_assignments):
-        rhs = replace_fraction_symbol_with_field(psm_config, main.rhs)
-
+        rhs =  main.rhs
         for p in range(psm_config.max_particles_per_cell):
             rhs += solid_collisions[p][i]
         collision_assignments.append(Assignment(main.lhs, rhs))
 
     collision_assignments = AssignmentCollection(collision_assignments)
-    ac = LbmCollisionRule(method, collision_assignments, collision_rule.subexpressions,
+    ac = LbmCollisionRule(method, replace_fraction_symbol_with_field(collision_assignments, psm_config),
+                          replace_fraction_symbol_with_field(collision_rule.subexpressions, psm_config),
                           collision_rule.simplification_hints)
     ac.topological_sort()
     return ac
-- 
GitLab