From cb7e961ddddd85cc31ea00520d4a02334b9cff8d Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 1 Dec 2023 09:59:09 +0100
Subject: [PATCH] Create regularized methods by default

---
 lbmpy/creationfunctions.py              | 20 ++++++++++++++------
 lbmpy_tests/test_mrt_simplifications.py |  8 ++------
 lbmpy_tests/test_relaxation_rate.py     | 15 +++++++++++++++
 3 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/lbmpy/creationfunctions.py b/lbmpy/creationfunctions.py
index e27501e2..47950456 100644
--- a/lbmpy/creationfunctions.py
+++ b/lbmpy/creationfunctions.py
@@ -112,13 +112,19 @@ class LBMConfig:
     """
     Sequence of relaxation rates, number depends on selected method. If you specify more rates than
     method needs, the additional rates are ignored.
+
+    If no relaxation rates are specified, the parameter `relaxation_rate` will be consulted.
     """
     relaxation_rate: Union[int, float, Type[sp.Symbol]] = None
     """
-    For SRT, TRT and polynomial cumulant models it is possible to define
-    a single ``relaxation_rate`` instead of a list (Internally this is converted to a list with a single entry).
-    The second rate for TRT is then determined via magic number. For the moment, central moment based and the
-    cumulant model, it sets only the relaxation rate corresponding to shear viscosity, setting all others to unity.
+    The method's primary relaxation rate. In most cases, this is the relaxation rate governing shear viscosity.
+    For SRT, this is the only relaxation rate.
+    For TRT, the second relaxation rate is then determined via magic number.
+    In the case of raw moment, central moment, and cumulant-based MRT methods, all other relaxation rates will be
+    set to unity.
+
+    If neither `relaxation_rate` nor `relaxation_rates` is specified, the behaviour is as if 
+    `relaxation_rate=sp.Symbol('omega')` was set.
     """
     compressible: bool = False
     """
@@ -332,9 +338,11 @@ class LBMConfig:
             self.stencil = LBStencil(self.stencil)
 
         if self.relaxation_rates is None:
-            self.relaxation_rates = [sp.Symbol("omega")] * self.stencil.Q
+            #   Fall back to regularized method
+            if self.relaxation_rate is None:
+                self.relaxation_rate = sp.Symbol("omega")
 
-        # if only a single relaxation rate is defined (which makes sense for SRT or TRT methods)
+        # if only a single relaxation rate is defined,
         # it is internally treated as a list with one element and just sets the relaxation_rates parameter
         if self.relaxation_rate is not None:
             if self.method in [Method.TRT, Method.TRT_KBC_N1, Method.TRT_KBC_N2, Method.TRT_KBC_N3, Method.TRT_KBC_N4]:
diff --git a/lbmpy_tests/test_mrt_simplifications.py b/lbmpy_tests/test_mrt_simplifications.py
index 0ffe8f0d..688abd46 100644
--- a/lbmpy_tests/test_mrt_simplifications.py
+++ b/lbmpy_tests/test_mrt_simplifications.py
@@ -6,11 +6,6 @@ from pystencils.sympyextensions import is_constant
 
 from lbmpy import Stencil, LBStencil, Method, create_lb_collision_rule, LBMConfig, LBMOptimisation
 
-# TODO:
-# Fully simplified kernels should NOT contain
-#  - Any aliases
-#  - Any in-line constants (all constants should be in subexpressions!)
-
 @pytest.mark.parametrize('method', [Method.MRT, Method.CENTRAL_MOMENT, Method.CUMULANT])
 def test_mrt_simplifications(method: Method):
     stencil = Stencil.D3Q19
@@ -33,7 +28,8 @@ def test_mrt_simplifications(method: Method):
         for expr in exprs:
             for arg in expr.args:
                 if isinstance(arg, sp.Number):
-                    assert arg in {sp.Number(1), sp.Number(-1)}
+                    if arg not in {sp.Number(1), sp.Number(-1), sp.Float(1), sp.Float(-1)}:
+                        breakpoint()
                     
         #   Check for divisions
         if not (isinstance(rhs, sp.Pow) and rhs.args[1] < 0):
diff --git a/lbmpy_tests/test_relaxation_rate.py b/lbmpy_tests/test_relaxation_rate.py
index d36f546d..df5d90d6 100644
--- a/lbmpy_tests/test_relaxation_rate.py
+++ b/lbmpy_tests/test_relaxation_rate.py
@@ -1,6 +1,7 @@
 import pytest
 import sympy as sp
 from lbmpy.creationfunctions import create_lb_method, LBMConfig
+from lbmpy.moments import is_shear_moment, get_order
 from lbmpy.enums import Method, Stencil
 from lbmpy.relaxationrates import get_shear_relaxation_rate
 from lbmpy.stencils import LBStencil
@@ -19,3 +20,17 @@ def test_relaxation_rate():
                            relaxation_rates=omegas)
     method = create_lb_method(lbm_config=lbm_config)
     assert get_shear_relaxation_rate(method) == omegas[0]
+
+
+@pytest.mark.parametrize("method", [Method.MRT, Method.CENTRAL_MOMENT, Method.CUMULANT])
+def test_default_mrt_behaviour(method):
+    lbm_config = LBMConfig(stencil=LBStencil(Stencil.D3Q19), method=method, compressible=True)
+    method = create_lb_method(lbm_config=lbm_config)
+
+    for moment, relax_info in method.relaxation_info_dict.items():
+        if get_order(moment) <= 1:
+            assert relax_info.relaxation_rate == 0
+        elif is_shear_moment(moment, method.dim):
+            assert relax_info.relaxation_rate == sp.Symbol('omega')
+        else:
+            assert relax_info.relaxation_rate == 1
-- 
GitLab