From b9073ad6922da9a2e1004221c3f10046d385e83b Mon Sep 17 00:00:00 2001
From: Philipp Suffa <philipp.suffa@fau.de>
Date: Tue, 13 May 2025 17:48:14 +0200
Subject: [PATCH] Fixes error, where psm can not be build with cumulants

---
 src/lbmpy/creationfunctions.py         |  2 +-
 src/lbmpy/custom_code_nodes.py         |  2 +-
 src/lbmpy/partially_saturated_cells.py | 20 +++++++++++++++-----
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/src/lbmpy/creationfunctions.py b/src/lbmpy/creationfunctions.py
index 96e164ea..5e253719 100644
--- a/src/lbmpy/creationfunctions.py
+++ b/src/lbmpy/creationfunctions.py
@@ -890,7 +890,7 @@ def create_psm_update_rule(lbm_config, lbm_optimisation):
 
     for p in range(lbm_config.psm_config.max_particles_per_cell):
 
-        psm_solid_collision = add_psm_solid_collision_to_collision_rule(lb_update_rule, lbm_config.psm_config, p)
+        psm_solid_collision = add_psm_solid_collision_to_collision_rule(lb_update_rule, lbm_config, p)
         psm_update_rule = create_lb_update_rule(
             collision_rule=psm_solid_collision, lbm_config=lbm_config, lbm_optimisation=lbm_optimisation)
 
diff --git a/src/lbmpy/custom_code_nodes.py b/src/lbmpy/custom_code_nodes.py
index 6540f382..2604a2b9 100644
--- a/src/lbmpy/custom_code_nodes.py
+++ b/src/lbmpy/custom_code_nodes.py
@@ -68,7 +68,7 @@ class LbmWeightInfo(CustomCodeNode):
         weights = [f"(({self.weights_symbol.dtype.c_name})({str(w.evalf(17))}))" for w in lb_method.weights]
         weights = ", ".join(weights)
         w_sym = self.weights_symbol
-        code = f"const {self.weights_symbol.dtype.c_name} {w_sym.name} [] = {{{ weights }}};\n"
+        code = f"const {self.weights_symbol.dtype.c_name} {w_sym.name} [] = {{{weights}}};\n"
         super(LbmWeightInfo, self).__init__(code, symbols_read=set(), symbols_defined={w_sym})
 
     def weight_of_direction(self, dir_idx, lb_method=None):
diff --git a/src/lbmpy/partially_saturated_cells.py b/src/lbmpy/partially_saturated_cells.py
index ea382194..35196e49 100644
--- a/src/lbmpy/partially_saturated_cells.py
+++ b/src/lbmpy/partially_saturated_cells.py
@@ -1,6 +1,7 @@
 import sympy as sp
 from dataclasses import dataclass
 
+from lbmpy.enums import Method
 from lbmpy.methods.abstractlbmethod import LbmCollisionRule
 from pystencils import Assignment, AssignmentCollection
 from pystencils.field import Field
@@ -111,19 +112,28 @@ def replace_fraction_symbol_with_field(assignments, psm_config):
     return new_assignments
 
 
-def add_psm_solid_collision_to_collision_rule(collision_rule, psm_config, particle_per_cell_counter):
+def add_psm_solid_collision_to_collision_rule(collision_rule, lbm_config, particle_per_cell_counter):
 
     method = collision_rule.method
-    solid_collisions = get_psm_solid_collision_term(collision_rule, psm_config, particle_per_cell_counter)
+    solid_collisions = get_psm_solid_collision_term(collision_rule, lbm_config.psm_config, particle_per_cell_counter)
     post_collision_pdf_symbols = method.post_collision_pdf_symbols
 
     assignments = []
     for sc, post in zip(solid_collisions, post_collision_pdf_symbols):
         assignments.append(Assignment(post, post + sc))
 
-    if psm_config.object_force_field is not None:
+    if lbm_config.psm_config.object_force_field is not None:
         assignments += get_psm_force_from_solid_collision(solid_collisions, method.stencil,
-                                                          psm_config.object_force_field, particle_per_cell_counter)
+                                                          lbm_config.psm_config.object_force_field,
+                                                          particle_per_cell_counter)
+
+    # exchanging rho with zeroth order moment symbol
+    if lbm_config.method in (Method.CENTRAL_MOMENT, Method.MONOMIAL_CUMULANT, Method.CUMULANT):
+        new_assignments = []
+        zeroth_moment_symbol = 'm_00' if lbm_config.stencil.D == 2 else 'm_000'
+        for ass in assignments:
+            new_assignments.append(ass.subs(sp.Symbol('rho'), sp.Symbol(zeroth_moment_symbol)))
+        assignments = new_assignments
 
     collision_assignments = AssignmentCollection(assignments)
     ac = LbmCollisionRule(method, collision_assignments, [],
@@ -144,7 +154,7 @@ def replace_by_psm_collision_rule(collision_rule, psm_config):
                                                                         psm_config.object_force_field, p)
 
     for i, main in enumerate(collision_rule.main_assignments):
-        rhs =  main.rhs
+        rhs = main.rhs
         for p in range(psm_config.max_particles_per_cell):
             rhs += solid_collisions[p][i]
         collision_assignments.append(Assignment(main.lhs, rhs))
-- 
GitLab