diff --git a/src/lbmpy/creationfunctions.py b/src/lbmpy/creationfunctions.py index ebac5b2fa8ebd4ea8cbc36b52ef3254d94dae8d5..96e164eae381e5309c705938ff39c5304dc66406 100644 --- a/src/lbmpy/creationfunctions.py +++ b/src/lbmpy/creationfunctions.py @@ -702,6 +702,11 @@ def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=N bulk_relaxation_rate=lbm_config.relaxation_rates[1], limiter=cumulant_limiter) + if lbm_config.psm_config is not None: + if lbm_config.psm_config.fraction_field is None or lbm_config.psm_config.object_velocity_field is None: + raise ValueError("Specify a fraction and object velocity field in the PSM Config") + collision_rule = replace_by_psm_collision_rule(collision_rule, lbm_config.psm_config) + if lbm_config.entropic: if lbm_config.subgrid_scale_model or lbm_config.cassons: raise ValueError("Choose either entropic, subgrid-scale or cassons") @@ -756,11 +761,6 @@ def create_lb_collision_rule(lb_method=None, lbm_config=None, lbm_optimisation=N if lbm_config.fluctuating: add_fluctuations_to_collision_rule(collision_rule, **lbm_config.fluctuating) - if lbm_config.psm_config is not None: - if lbm_config.psm_config.fraction_field is None or lbm_config.psm_config.object_velocity_field is None: - raise ValueError("Specify a fraction and object velocity field in the PSM Config") - collision_rule = replace_by_psm_collision_rule(collision_rule, lbm_config.psm_config) - if lbm_optimisation.cse_pdfs: from lbmpy.methods.momentbased.momentbasedsimplifications import cse_in_opposing_directions collision_rule = cse_in_opposing_directions(collision_rule) diff --git a/src/lbmpy/partially_saturated_cells.py b/src/lbmpy/partially_saturated_cells.py index 1a66f4456203825faa5e198346d90c8eaf5fa41d..ea3821946e3bf16e6d60898109572dfb589a7099 100644 --- a/src/lbmpy/partially_saturated_cells.py +++ b/src/lbmpy/partially_saturated_cells.py @@ -103,8 +103,12 @@ def get_psm_force_from_solid_collision(solid_collisions, stencil, object_force_f return AssignmentCollection(force_assignments) -def replace_fraction_symbol_with_field(psm_config, assignment): - return assignment.subs(psm_config.fraction_field_symbol, psm_config.fraction_field.center(0)) +def replace_fraction_symbol_with_field(assignments, psm_config): + new_assignments = [] + for ass in assignments: + rhs = ass.rhs.subs(psm_config.fraction_field_symbol, psm_config.fraction_field.center(0)) + new_assignments.append(Assignment(ass.lhs, rhs)) + return new_assignments def add_psm_solid_collision_to_collision_rule(collision_rule, psm_config, particle_per_cell_counter): @@ -140,14 +144,14 @@ def replace_by_psm_collision_rule(collision_rule, psm_config): psm_config.object_force_field, p) for i, main in enumerate(collision_rule.main_assignments): - rhs = replace_fraction_symbol_with_field(psm_config, main.rhs) - + rhs = main.rhs for p in range(psm_config.max_particles_per_cell): rhs += solid_collisions[p][i] collision_assignments.append(Assignment(main.lhs, rhs)) collision_assignments = AssignmentCollection(collision_assignments) - ac = LbmCollisionRule(method, collision_assignments, collision_rule.subexpressions, + ac = LbmCollisionRule(method, replace_fraction_symbol_with_field(collision_assignments, psm_config), + replace_fraction_symbol_with_field(collision_rule.subexpressions, psm_config), collision_rule.simplification_hints) ac.topological_sort() return ac