diff --git a/lbmpy/simplificationfactory.py b/lbmpy/simplificationfactory.py index b16fbcbfa7d03eeff433e5f552e524ee20522319..331f282d7f94b3225d283e0db622dc12bd24d06d 100644 --- a/lbmpy/simplificationfactory.py +++ b/lbmpy/simplificationfactory.py @@ -92,6 +92,7 @@ def _cumulant_space_simplification(split_inner_loop): s.add(expand_post_collision_central_moments) s.add(insert_aliases) s.add(insert_constants) + s.add(insert_aliases) s.add(add_subexpressions_for_divisions) s.add(add_subexpressions_for_constants) if split_inner_loop: diff --git a/lbmpy_tests/test_mrt_simplifications.py b/lbmpy_tests/test_mrt_simplifications.py new file mode 100644 index 0000000000000000000000000000000000000000..0ffe8f0d4e6663b9b899bd666dc364e44eba8a42 --- /dev/null +++ b/lbmpy_tests/test_mrt_simplifications.py @@ -0,0 +1,42 @@ +import pytest + +import sympy as sp + +from pystencils.sympyextensions import is_constant + +from lbmpy import Stencil, LBStencil, Method, create_lb_collision_rule, LBMConfig, LBMOptimisation + +# TODO: +# Fully simplified kernels should NOT contain +# - Any aliases +# - Any in-line constants (all constants should be in subexpressions!) + +@pytest.mark.parametrize('method', [Method.MRT, Method.CENTRAL_MOMENT, Method.CUMULANT]) +def test_mrt_simplifications(method: Method): + stencil = Stencil.D3Q19 + lbm_config = LBMConfig(stencil=stencil, method=method, compressible=True) + lbm_opt = LBMOptimisation(simplification='auto') + + cr = create_lb_collision_rule(lbm_config=lbm_config, lbm_optimisation=lbm_opt) + + + for subexp in cr.subexpressions: + rhs = subexp.rhs + # Check for aliases + assert not isinstance(rhs, sp.Symbol) + + # Check for logarithms + assert not rhs.atoms(sp.log) + + # Check for nonextracted constant summands or factors + exprs = rhs.atoms(sp.Add, sp.Mul) + for expr in exprs: + for arg in expr.args: + if isinstance(arg, sp.Number): + assert arg in {sp.Number(1), sp.Number(-1)} + + # Check for divisions + if not (isinstance(rhs, sp.Pow) and rhs.args[1] < 0): + powers = rhs.atoms(sp.Pow) + for p in powers: + assert p.args[1] > 0