From 8df08f0a93635daadf9172dc5329c8a310f9b306 Mon Sep 17 00:00:00 2001
From: Christian Godenschwager <christian.godenschwager@fau.de>
Date: Fri, 27 Oct 2017 19:02:46 +0200
Subject: [PATCH] Make createLatticeBoltzmannMethod work with arbitrary
 stencils

---
 creationfunctions.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/creationfunctions.py b/creationfunctions.py
index 531549c4..dbfa5aca 100644
--- a/creationfunctions.py
+++ b/creationfunctions.py
@@ -154,7 +154,7 @@ from lbmpy.methods import createSRT, createTRT, createOrthogonalMRT, createKBCTy
 from lbmpy.methods.entropic import addIterativeEntropyCondition, addEntropyCondition
 from lbmpy.methods.entropic_eq_srt import createEntropicSRT
 from lbmpy.methods.relaxationrates import relaxationRateFromMagicNumber
-from lbmpy.stencils import getStencil
+from lbmpy.stencils import getStencil, stencilsHaveSameEntries
 import lbmpy.forcemodels as forceModels
 from lbmpy.simplificationfactory import createSimplificationStrategy
 from lbmpy.updatekernels import StreamPullTwoFieldsAccessor, PeriodicTwoFieldsAccessor, CollideOnlyInplaceAccessor, \
@@ -414,10 +414,14 @@ def _createLatticeBoltzmannUpdateRuleCached(stringParameters, forceModel, force,
 
 
 def createLatticeBoltzmannMethod(**params):
-    params, _ = updateWithDefaultParameters(params, {})
+    params, _ = updateWithDefaultParameters(params, {}, failOnUnknownParameter=False)
 
-    stencil = getStencil(params['stencil'])
-    dim = len(stencil[0])
+    if 'stencilList' in params:
+        stencilList = params['stencilList']
+    else:
+        stencilList = getStencil(params['stencil'])
+
+    dim = len(stencilList[0])
 
     forceIsZero = True
     for f_i in params['force']:
@@ -460,10 +464,10 @@ def createLatticeBoltzmannMethod(**params):
 
     if methodName.lower() == 'srt':
         assert len(relaxationRates) >= 1, "Not enough relaxation rates"
-        method = createSRT(stencil, relaxationRates[0], **commonParams)
+        method = createSRT(stencilList, relaxationRates[0], **commonParams)
     elif methodName.lower() == 'trt':
         assert len(relaxationRates) >= 2, "Not enough relaxation rates"
-        method = createTRT(stencil, relaxationRates[0], relaxationRates[1], **commonParams)
+        method = createTRT(stencilList, relaxationRates[0], relaxationRates[1], **commonParams)
     elif methodName.lower() == 'mrt':
         nextRelaxationRate = [0]
 
@@ -471,22 +475,22 @@ def createLatticeBoltzmannMethod(**params):
             res = relaxationRates[nextRelaxationRate[0]]
             nextRelaxationRate[0] += 1
             return res
-        method = createOrthogonalMRT(stencil, relaxationRateGetter, **commonParams)
+        method = createOrthogonalMRT(stencilList, relaxationRateGetter, **commonParams)
     elif methodName.lower() == 'mrt_raw':
-        method = createRawMRT(stencil, relaxationRates, **commonParams)
+        method = createRawMRT(stencilList, relaxationRates, **commonParams)
     elif methodName.lower() == 'mrt3':
-        method = createThreeRelaxationRateMRT(stencil, relaxationRates, **commonParams)
+        method = createThreeRelaxationRateMRT(stencilList, relaxationRates, **commonParams)
     elif methodName.lower().startswith('trt-kbc-n'):
-        if params['stencil'] == 'D2Q9':
+        if stencilsHaveSameEntries(stencilList, getStencil("D2Q9")):
             dim = 2
-        elif params['stencil'] == 'D3Q27':
+        elif stencilsHaveSameEntries(stencilList, getStencil("D3Q27")):
             dim = 3
         else:
             raise NotImplementedError("KBC type TRT methods can only be constructed for D2Q9 and D3Q27 stencils")
         methodNr = methodName[-1]
         method = createKBCTypeTRT(dim, relaxationRates[0], relaxationRates[1], 'KBC-N' + methodNr, **commonParams)
     elif methodName.lower() == 'entropic-srt':
-        method = createEntropicSRT(stencil, relaxationRates[0], forceModel, params['compressible'])
+        method = createEntropicSRT(stencilList, relaxationRates[0], forceModel, params['compressible'])
     else:
         raise ValueError("Unknown method %s" % (methodName,))
 
-- 
GitLab