Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision

Target

Select target project
  • ravi.k.ayyala/lbmpy
  • brendan-waters/lbmpy
  • anirudh.jonnalagadda/lbmpy
  • jbadwaik/lbmpy
  • alexander.reinauer/lbmpy
  • itischler/lbmpy
  • he66coqe/lbmpy
  • ev81oxyl/lbmpy
  • Bindgen/lbmpy
  • da15siwa/lbmpy
  • holzer/lbmpy
  • RudolfWeeber/lbmpy
  • pycodegen/lbmpy
13 results
Select Git revision
Show changes
Showing
with 839 additions and 15 deletions
...@@ -185,6 +185,10 @@ def _predefined_stencils(stencil: str, ordering: str): ...@@ -185,6 +185,10 @@ def _predefined_stencils(stencil: str, ordering: str):
(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1), (1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1),
(1, 1, 1), (-1, -1, -1), (1, 1, -1), (-1, -1, 1), (1, 1, 1), (-1, -1, -1), (1, 1, -1), (-1, -1, 1),
(1, -1, 1), (-1, 1, -1), (-1, 1, 1), (1, -1, -1)), (1, -1, 1), (-1, 1, -1), (-1, 1, 1), (1, -1, -1)),
'fakhari': ((0, 0, 0),
(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1),
(1, 1, 1), (-1, -1, -1), (-1, 1, 1), (1, -1, -1),
(1, -1, 1), (-1, 1, -1), (1, 1, -1), (-1, -1, 1)),
}, },
'D3Q19': { 'D3Q19': {
'walberla': ((0, 0, 0), 'walberla': ((0, 0, 0),
...@@ -242,9 +246,15 @@ def _predefined_stencils(stencil: str, ordering: str): ...@@ -242,9 +246,15 @@ def _predefined_stencils(stencil: str, ordering: str):
(1, 0, -1), (-1, 0, 1), (0, 1, -1), (0, -1, 1), (1, 0, -1), (-1, 0, 1), (0, 1, -1), (0, -1, 1),
(1, 1, 1), (-1, -1, -1), (1, 1, -1), (-1, -1, 1), (1, -1, 1), (-1, 1, -1), (-1, 1, 1), (1, 1, 1), (-1, -1, -1), (1, 1, -1), (-1, -1, 1), (1, -1, 1), (-1, 1, -1), (-1, 1, 1),
(1, -1, -1)), (1, -1, -1)),
'braunschweig': ((0, 0, 0),
(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1),
(1, 1, 0), (-1, -1, 0), (1, -1, 0), (-1, 1, 0),
(1, 0, 1), (-1, 0, -1), (1, 0, -1), (-1, 0, 1),
(0, 1, 1), (0, -1, -1), (0, 1, -1), (0, -1, 1),
(1, 1, 1), (-1, 1, 1), (1, -1, 1), (-1, -1, 1), (1, 1, -1), (-1, 1, -1), (1, -1, -1),
(-1, -1, -1)),
} }
} }
try: try:
return predefined_stencils[stencil][ordering] return predefined_stencils[stencil][ordering]
except KeyError: except KeyError:
......
import sympy as sp import sympy as sp
from lbmpy.relaxationrates import get_shear_relaxation_rate from lbmpy.relaxationrates import get_shear_relaxation_rate, relaxation_rate_from_lattice_viscosity, \
lattice_viscosity_from_relaxation_rate
from lbmpy.utils import extract_shear_relaxation_rate, frobenius_norm, second_order_moment_tensor from lbmpy.utils import extract_shear_relaxation_rate, frobenius_norm, second_order_moment_tensor
from pystencils import Assignment from pystencils import Assignment
from lbmpy.enums import SubgridScaleModel
def add_smagorinsky_model(collision_rule, smagorinsky_constant, omega_output_field=None):
r""" Adds a smagorinsky model to a lattice Boltzmann collision rule. To add the Smagorinsky model to a LB scheme def add_sgs_model(collision_rule, subgrid_scale_model: SubgridScaleModel, model_constant=None, omega_output_field=None,
eddy_viscosity_field=None):
r""" Wrapper for SGS models to provide default constants and outsource SGS model handling from creation routines."""
if subgrid_scale_model == SubgridScaleModel.SMAGORINSKY:
model_constant = model_constant if model_constant else sp.Float(0.12)
return add_smagorinsky_model(collision_rule=collision_rule, smagorinsky_constant=model_constant,
omega_output_field=omega_output_field, eddy_viscosity_field=eddy_viscosity_field)
if subgrid_scale_model == SubgridScaleModel.QR:
model_constant = model_constant if model_constant else sp.Rational(1, 3)
return add_qr_model(collision_rule=collision_rule, qr_constant=model_constant,
omega_output_field=omega_output_field, eddy_viscosity_field=eddy_viscosity_field)
def add_smagorinsky_model(collision_rule, smagorinsky_constant, omega_output_field=None, eddy_viscosity_field=None):
r""" Adds a Smagorinsky model to a lattice Boltzmann collision rule. To add the Smagorinsky model to a LB scheme
one has to first compute the strain rate tensor $S_{ij}$ in each cell, and compute the turbulent one has to first compute the strain rate tensor $S_{ij}$ in each cell, and compute the turbulent
viscosity :math:`nu_t` from it. Then the local relaxation rate has to be adapted to match the total viscosity viscosity :math:`nu_t` from it. Then the local relaxation rate has to be adapted to match the total viscosity
:math `\nu_{total}` instead of the standard viscosity :math `\nu_0`. :math `\nu_{total}` instead of the standard viscosity :math `\nu_0`.
...@@ -34,7 +51,7 @@ def add_smagorinsky_model(collision_rule, smagorinsky_constant, omega_output_fie ...@@ -34,7 +51,7 @@ def add_smagorinsky_model(collision_rule, smagorinsky_constant, omega_output_fie
omega_s, found_symbolic_shear_relaxation = extract_shear_relaxation_rate(collision_rule, omega_s) omega_s, found_symbolic_shear_relaxation = extract_shear_relaxation_rate(collision_rule, omega_s)
if not found_symbolic_shear_relaxation: if not found_symbolic_shear_relaxation:
raise ValueError("For the smagorinsky model the shear relaxation rate has to be a symbol or it has to be " raise ValueError("For the Smagorinsky model the shear relaxation rate has to be a symbol or it has to be "
"assigned to a single equation in the assignment list") "assigned to a single equation in the assignment list")
f_neq = sp.Matrix(method.pre_collision_pdf_symbols) - method.get_equilibrium_terms() f_neq = sp.Matrix(method.pre_collision_pdf_symbols) - method.get_equilibrium_terms()
equilibrium = method.equilibrium_distribution equilibrium = method.equilibrium_distribution
...@@ -42,18 +59,91 @@ def add_smagorinsky_model(collision_rule, smagorinsky_constant, omega_output_fie ...@@ -42,18 +59,91 @@ def add_smagorinsky_model(collision_rule, smagorinsky_constant, omega_output_fie
tau_0 = sp.Symbol("tau_0_") tau_0 = sp.Symbol("tau_0_")
second_order_neq_moments = sp.Symbol("Pi") second_order_neq_moments = sp.Symbol("Pi")
rho = equilibrium.density if equilibrium.compressible else equilibrium.background_density rho = equilibrium.density if equilibrium.compressible else equilibrium.background_density
adapted_omega = sp.Symbol("smagorinsky_omega") adapted_omega = sp.Symbol("sgs_omega")
collision_rule = collision_rule.new_with_substitutions({omega_s: adapted_omega}, substitute_on_lhs=False) collision_rule = collision_rule.new_with_substitutions({omega_s: adapted_omega}, substitute_on_lhs=False)
# for derivation see notebook demo_custom_LES_model.ipynb # for derivation see notebook demo_custom_LES_model.ipynb
eqs = [Assignment(tau_0, 1 / omega_s), eqs = [Assignment(tau_0, sp.Float(1) / omega_s),
Assignment(second_order_neq_moments, Assignment(second_order_neq_moments,
frobenius_norm(second_order_moment_tensor(f_neq, method.stencil), factor=2) / rho), frobenius_norm(second_order_moment_tensor(f_neq, method.stencil), factor=2) / rho),
Assignment(adapted_omega, Assignment(adapted_omega,
2 / (tau_0 + sp.sqrt(18 * smagorinsky_constant ** 2 * second_order_neq_moments + tau_0 ** 2)))] sp.Float(2) / (tau_0 + sp.sqrt(
sp.Float(18) * smagorinsky_constant ** 2 * second_order_neq_moments + tau_0 ** 2)))]
collision_rule.subexpressions += eqs collision_rule.subexpressions += eqs
collision_rule.topological_sort(sort_subexpressions=True, sort_main_assignments=False) collision_rule.topological_sort(sort_subexpressions=True, sort_main_assignments=False)
if eddy_viscosity_field:
collision_rule.main_assignments.append(Assignment(
eddy_viscosity_field.center,
smagorinsky_constant ** 2 * sp.Rational(3, 2) * adapted_omega * second_order_neq_moments
))
if omega_output_field:
collision_rule.main_assignments.append(Assignment(omega_output_field.center, adapted_omega))
return collision_rule
def add_qr_model(collision_rule, qr_constant, omega_output_field=None, eddy_viscosity_field=None):
r""" Adds a QR model to a lattice Boltzmann collision rule, see :cite:`rozema15`.
WARNING : this subgrid-scale model is only defined for isotropic grids
"""
method = collision_rule.method
omega_s = get_shear_relaxation_rate(method)
omega_s, found_symbolic_shear_relaxation = extract_shear_relaxation_rate(collision_rule, omega_s)
if not found_symbolic_shear_relaxation:
raise ValueError("For the QR model the shear relaxation rate has to be a symbol or it has to be "
"assigned to a single equation in the assignment list")
f_neq = sp.Matrix(method.pre_collision_pdf_symbols) - method.get_equilibrium_terms()
equilibrium = method.equilibrium_distribution
nu_0, nu_e = sp.symbols("qr_nu_0 qr_nu_e")
c_pi_s = sp.Symbol("qr_c_pi")
rho = equilibrium.density if equilibrium.compressible else equilibrium.background_density
adapted_omega = sp.Symbol("sgs_omega")
stencil = method.stencil
pi = second_order_moment_tensor(f_neq, stencil)
r_prime = sp.Symbol("qr_r_prime")
q_prime = sp.Symbol("qr_q_prime")
c_pi = qr_constant * sp.Piecewise(
(r_prime, sp.StrictGreaterThan(r_prime, sp.Float(0))),
(sp.Float(0), True)
) / q_prime
collision_rule = collision_rule.new_with_substitutions({omega_s: adapted_omega}, substitute_on_lhs=False)
if stencil.D == 2:
nu_e_assignments = [Assignment(nu_e, c_pi_s)]
elif stencil.D == 3:
base_viscosity = sp.Symbol("qr_base_viscosity")
nu_e_assignments = [
Assignment(base_viscosity, sp.Float(6) * nu_0 + sp.Float(1)),
Assignment(nu_e, (-base_viscosity + sp.sqrt(base_viscosity ** 2 + sp.Float(72) * c_pi_s / rho))
/ sp.Float(12))
]
else:
raise ValueError("QR-model is only defined for 2- or 3-dimensional flows")
matrix_entries = sp.Matrix(stencil.D, stencil.D, sp.symbols(f"sgs_qr_pi:{stencil.D ** 2}"))
eqs = [Assignment(nu_0, lattice_viscosity_from_relaxation_rate(omega_s)),
*[Assignment(matrix_entries[i], pi[i]) for i in range(stencil.D ** 2)],
Assignment(r_prime, sp.Float(-1) ** (stencil.D + 1) * matrix_entries.det()),
Assignment(q_prime, sp.Rational(1, 2) * (matrix_entries * matrix_entries).trace()),
Assignment(c_pi_s, c_pi),
*nu_e_assignments,
Assignment(adapted_omega, relaxation_rate_from_lattice_viscosity(nu_0 + nu_e))]
collision_rule.subexpressions += eqs
collision_rule.topological_sort(sort_subexpressions=True, sort_main_assignments=False)
if eddy_viscosity_field:
collision_rule.main_assignments.append(Assignment(eddy_viscosity_field.center, nu_e))
if omega_output_field: if omega_output_field:
collision_rule.main_assignments.append(Assignment(omega_output_field.center, adapted_omega)) collision_rule.main_assignments.append(Assignment(omega_output_field.center, adapted_omega))
......
...@@ -10,7 +10,8 @@ from pystencils.sympyextensions import fast_subs ...@@ -10,7 +10,8 @@ from pystencils.sympyextensions import fast_subs
# -------------------------------------------- LBM Kernel Creation ----------------------------------------------------- # -------------------------------------------- LBM Kernel Creation -----------------------------------------------------
def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=StreamPullTwoFieldsAccessor()): def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=StreamPullTwoFieldsAccessor(),
data_type=None):
"""Replaces the pre- and post collision symbols in the collision rule by field accesses. """Replaces the pre- and post collision symbols in the collision rule by field accesses.
Args: Args:
...@@ -19,6 +20,7 @@ def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=Stream ...@@ -19,6 +20,7 @@ def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=Stream
dst_field: field used for writing pdf values if accessor.is_inplace this parameter is ignored dst_field: field used for writing pdf values if accessor.is_inplace this parameter is ignored
accessor: instance of PdfFieldAccessor, defining where to read and write values accessor: instance of PdfFieldAccessor, defining where to read and write values
to create e.g. a fused stream-collide kernel See 'fieldaccess.PdfFieldAccessor' to create e.g. a fused stream-collide kernel See 'fieldaccess.PdfFieldAccessor'
data_type: If a datatype is given the field reads are isolated and written to TypedSymbols of type data_type
Returns: Returns:
LbmCollisionRule where pre- and post collision symbols have been replaced LbmCollisionRule where pre- and post collision symbols have been replaced
...@@ -49,8 +51,9 @@ def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=Stream ...@@ -49,8 +51,9 @@ def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=Stream
new_split_groups.append([fast_subs(e, substitutions) for e in split_group]) new_split_groups.append([fast_subs(e, substitutions) for e in split_group])
result.simplification_hints['split_groups'] = new_split_groups result.simplification_hints['split_groups'] = new_split_groups
if accessor.is_inplace: if accessor.is_inplace or (data_type is not None and src_field.dtype != data_type):
result = add_subexpressions_for_field_reads(result, subexpressions=True, main_assignments=True) result = add_subexpressions_for_field_reads(result, subexpressions=True, main_assignments=True,
data_type=data_type)
return result return result
...@@ -117,6 +120,26 @@ def create_stream_pull_with_output_kernel(lb_method, src_field, dst_field=None, ...@@ -117,6 +120,26 @@ def create_stream_pull_with_output_kernel(lb_method, src_field, dst_field=None,
simplification_hints=output_eq_collection.simplification_hints) simplification_hints=output_eq_collection.simplification_hints)
def create_copy_kernel(stencil, src_field, dst_field, accessor=StreamPullTwoFieldsAccessor()):
"""Creates a copy kernel, which can be used to transfer information from src to dst field.
Args:
stencil: lattice Boltzmann stencil which is used in the form of a tuple of tuples
src_field: field used for reading pdf values
dst_field: field used for writing pdf values
accessor: instance of PdfFieldAccessor, defining where to read and write values
to create e.g. a fused stream-collide kernel See 'fieldaccess.PdfFieldAccessor'
Returns:
AssignmentCollection of a copy update rule
"""
temporary_symbols = sp.symbols(f'copied:{stencil.Q}')
subexpressions = [Assignment(tmp, acc) for tmp, acc in zip(temporary_symbols, accessor.write(src_field, stencil))]
main_assignments = [Assignment(acc, tmp) for acc, tmp in zip(accessor.write(dst_field, stencil), temporary_symbols)]
return AssignmentCollection(main_assignments, subexpressions=subexpressions)
# ---------------------------------- Pdf array creation for various layouts -------------------------------------------- # ---------------------------------- Pdf array creation for various layouts --------------------------------------------
......
File moved
...@@ -5,36 +5,52 @@ import numpy as np ...@@ -5,36 +5,52 @@ import numpy as np
from lbmpy.stencils import LBStencil from lbmpy.stencils import LBStencil
from pystencils.slicing import get_slice_before_ghost_layer, get_ghost_region_slice from pystencils.slicing import get_slice_before_ghost_layer, get_ghost_region_slice
from lbmpy.creationfunctions import create_lb_update_rule, LBMConfig, LBMOptimisation from lbmpy.creationfunctions import create_lb_update_rule, LBMConfig, LBMOptimisation
from lbmpy.advanced_streaming.communication import get_communication_slices, _fix_length_one_slices, \ from lbmpy.advanced_streaming.communication import (
LBMPeriodicityHandling get_communication_slices,
_fix_length_one_slices,
LBMPeriodicityHandling,
periodic_pdf_gpu_copy_kernel,
)
from lbmpy.advanced_streaming.utility import streaming_patterns, Timestep from lbmpy.advanced_streaming.utility import streaming_patterns, Timestep
from lbmpy.enums import Stencil from lbmpy.enums import Stencil
import pytest import pytest
@pytest.mark.parametrize('stencil', [Stencil.D2Q9, Stencil.D3Q15, Stencil.D3Q19, Stencil.D3Q27]) @pytest.mark.parametrize(
@pytest.mark.parametrize('streaming_pattern', streaming_patterns) "stencil", [Stencil.D2Q9, Stencil.D3Q15, Stencil.D3Q19, Stencil.D3Q27]
@pytest.mark.parametrize('timestep', [Timestep.EVEN, Timestep.ODD]) )
@pytest.mark.parametrize("streaming_pattern", streaming_patterns)
@pytest.mark.parametrize("timestep", [Timestep.EVEN, Timestep.ODD])
def test_slices_not_empty(stencil, streaming_pattern, timestep): def test_slices_not_empty(stencil, streaming_pattern, timestep):
stencil = LBStencil(stencil) stencil = LBStencil(stencil)
arr = np.zeros((4,) * stencil.D + (stencil.Q,)) arr = np.zeros((4,) * stencil.D + (stencil.Q,))
slices = get_communication_slices(stencil, streaming_pattern=streaming_pattern, prev_timestep=timestep, slices = get_communication_slices(
ghost_layers=1) stencil,
streaming_pattern=streaming_pattern,
prev_timestep=timestep,
ghost_layers=1,
)
for _, slices_list in slices.items(): for _, slices_list in slices.items():
for src, dst in slices_list: for src, dst in slices_list:
assert all(s != 0 for s in arr[src].shape) assert all(s != 0 for s in arr[src].shape)
assert all(s != 0 for s in arr[dst].shape) assert all(s != 0 for s in arr[dst].shape)
@pytest.mark.parametrize('stencil', [Stencil.D2Q9, Stencil.D3Q15, Stencil.D3Q19, Stencil.D3Q27]) @pytest.mark.parametrize(
@pytest.mark.parametrize('streaming_pattern', streaming_patterns) "stencil", [Stencil.D2Q9, Stencil.D3Q15, Stencil.D3Q19, Stencil.D3Q27]
@pytest.mark.parametrize('timestep', [Timestep.EVEN, Timestep.ODD]) )
@pytest.mark.parametrize("streaming_pattern", streaming_patterns)
@pytest.mark.parametrize("timestep", [Timestep.EVEN, Timestep.ODD])
def test_src_dst_same_shape(stencil, streaming_pattern, timestep): def test_src_dst_same_shape(stencil, streaming_pattern, timestep):
stencil = LBStencil(stencil) stencil = LBStencil(stencil)
arr = np.zeros((4,) * stencil.D + (stencil.Q,)) arr = np.zeros((4,) * stencil.D + (stencil.Q,))
slices = get_communication_slices(stencil, streaming_pattern=streaming_pattern, prev_timestep=timestep, slices = get_communication_slices(
ghost_layers=1) stencil,
streaming_pattern=streaming_pattern,
prev_timestep=timestep,
ghost_layers=1,
)
for _, slices_list in slices.items(): for _, slices_list in slices.items():
for src, dst in slices_list: for src, dst in slices_list:
src_shape = arr[src].shape src_shape = arr[src].shape
...@@ -42,12 +58,15 @@ def test_src_dst_same_shape(stencil, streaming_pattern, timestep): ...@@ -42,12 +58,15 @@ def test_src_dst_same_shape(stencil, streaming_pattern, timestep):
assert src_shape == dst_shape assert src_shape == dst_shape
@pytest.mark.parametrize('stencil', [Stencil.D2Q9, Stencil.D3Q15, Stencil.D3Q19, Stencil.D3Q27]) @pytest.mark.parametrize(
"stencil", [Stencil.D2Q9, Stencil.D3Q15, Stencil.D3Q19, Stencil.D3Q27]
)
def test_pull_communication_slices(stencil): def test_pull_communication_slices(stencil):
stencil = LBStencil(stencil) stencil = LBStencil(stencil)
slices = get_communication_slices( slices = get_communication_slices(
stencil, streaming_pattern='pull', prev_timestep=Timestep.BOTH, ghost_layers=1) stencil, streaming_pattern="pull", prev_timestep=Timestep.BOTH, ghost_layers=1
)
for i, d in enumerate(stencil): for i, d in enumerate(stencil):
if i == 0: if i == 0:
continue continue
...@@ -58,21 +77,115 @@ def test_pull_communication_slices(stencil): ...@@ -58,21 +77,115 @@ def test_pull_communication_slices(stencil):
dst = s[1][:-1] dst = s[1][:-1]
break break
inner_slice = _fix_length_one_slices(get_slice_before_ghost_layer(d, ghost_layers=1)) inner_slice = _fix_length_one_slices(
get_slice_before_ghost_layer(d, ghost_layers=1)
)
inv_dir = (-e for e in d) inv_dir = (-e for e in d)
gl_slice = _fix_length_one_slices(get_ghost_region_slice(inv_dir, ghost_layers=1)) gl_slice = _fix_length_one_slices(
get_ghost_region_slice(inv_dir, ghost_layers=1)
)
assert src == inner_slice assert src == inner_slice
assert dst == gl_slice assert dst == gl_slice
@pytest.mark.parametrize('stencil_name', [Stencil.D2Q9, Stencil.D3Q15, Stencil.D3Q19, Stencil.D3Q27]) @pytest.mark.parametrize("direction", LBStencil(Stencil.D3Q27).stencil_entries)
@pytest.mark.parametrize("pull", [False, True])
def test_gpu_comm_kernels(direction: tuple, pull: bool):
pytest.importorskip("cupy")
stencil = LBStencil(Stencil.D3Q27)
inv_dir = stencil[stencil.inverse_index(direction)]
target = ps.Target.GPU
domain_size = (4,) * stencil.D
dh: ps.datahandling.SerialDataHandling = ps.create_data_handling(
domain_size,
periodicity=(True,) * stencil.D,
parallel=False,
default_target=target,
)
field = dh.add_array("field", values_per_cell=2)
if pull:
dst_slice = get_ghost_region_slice(inv_dir)
src_slice = get_slice_before_ghost_layer(direction)
else:
dst_slice = get_slice_before_ghost_layer(direction)
src_slice = get_ghost_region_slice(inv_dir)
src_slice += (1,)
dst_slice += (1,)
kernel = periodic_pdf_gpu_copy_kernel(field, src_slice, dst_slice)
dh.cpu_arrays[field.name][src_slice] = 42.0
dh.all_to_gpu()
dh.run_kernel(kernel)
dh.all_to_cpu()
np.testing.assert_equal(dh.cpu_arrays[field.name][dst_slice], 42.0)
@pytest.mark.parametrize("stencil", [Stencil.D2Q9, Stencil.D3Q19])
@pytest.mark.parametrize("streaming_pattern", streaming_patterns)
def test_direct_copy_and_kernels_equivalence(stencil: Stencil, streaming_pattern: str):
pytest.importorskip("cupy")
target = ps.Target.GPU
stencil = LBStencil(stencil)
domain_size = (4,) * stencil.D
dh: ps.datahandling.SerialDataHandling = ps.create_data_handling(
domain_size,
periodicity=(True,) * stencil.D,
parallel=False,
default_target=target,
)
pdfs_a = dh.add_array("pdfs_a", values_per_cell=stencil.Q)
pdfs_b = dh.add_array("pdfs_b", values_per_cell=stencil.Q)
dh.fill(pdfs_a.name, 0.0, ghost_layers=True)
dh.fill(pdfs_b.name, 0.0, ghost_layers=True)
for q in range(stencil.Q):
sl = ps.make_slice[:4, :4, q] if stencil.D == 2 else ps.make_slice[:4, :4, :4, q]
dh.cpu_arrays[pdfs_a.name][sl] = q
dh.cpu_arrays[pdfs_b.name][sl] = q
dh.all_to_gpu()
direct_copy = LBMPeriodicityHandling(stencil, dh, pdfs_a.name, streaming_pattern, cupy_direct_copy=True)
kernels_copy = LBMPeriodicityHandling(stencil, dh, pdfs_b.name, streaming_pattern, cupy_direct_copy=False)
direct_copy(Timestep.EVEN)
kernels_copy(Timestep.EVEN)
dh.all_to_cpu()
np.testing.assert_equal(
dh.cpu_arrays[pdfs_a.name],
dh.cpu_arrays[pdfs_b.name]
)
@pytest.mark.parametrize(
"stencil_name", [Stencil.D2Q9, Stencil.D3Q15, Stencil.D3Q19, Stencil.D3Q27]
)
def test_optimised_and_full_communication_equivalence(stencil_name): def test_optimised_and_full_communication_equivalence(stencil_name):
target = ps.Target.CPU target = ps.Target.CPU
stencil = LBStencil(stencil_name) stencil = LBStencil(stencil_name)
domain_size = (4, ) * stencil.D domain_size = (4,) * stencil.D
dh = ps.create_data_handling(domain_size, periodicity=(True, ) * stencil.D, dh = ps.create_data_handling(
parallel=False, default_target=target) domain_size,
periodicity=(True,) * stencil.D,
parallel=False,
default_target=target,
)
pdf = dh.add_array("pdf", values_per_cell=len(stencil), dtype=np.int64) pdf = dh.add_array("pdf", values_per_cell=len(stencil), dtype=np.int64)
dh.fill("pdf", 0, ghost_layers=True) dh.fill("pdf", 0, ghost_layers=True)
...@@ -82,9 +195,9 @@ def test_optimised_and_full_communication_equivalence(stencil_name): ...@@ -82,9 +195,9 @@ def test_optimised_and_full_communication_equivalence(stencil_name):
gl = dh.ghost_layers_of_field("pdf") gl = dh.ghost_layers_of_field("pdf")
num = 0 num = 0
for idx, x in np.ndenumerate(dh.cpu_arrays['pdf']): for idx, x in np.ndenumerate(dh.cpu_arrays["pdf"]):
dh.cpu_arrays['pdf'][idx] = num dh.cpu_arrays["pdf"][idx] = num
dh.cpu_arrays['pdf_tmp'][idx] = num dh.cpu_arrays["pdf_tmp"][idx] = num
num += 1 num += 1
lbm_config = LBMConfig(stencil=stencil, kernel_type="stream_pull_only") lbm_config = LBMConfig(stencil=stencil, kernel_type="stream_pull_only")
...@@ -95,21 +208,27 @@ def test_optimised_and_full_communication_equivalence(stencil_name): ...@@ -95,21 +208,27 @@ def test_optimised_and_full_communication_equivalence(stencil_name):
ast = ps.create_kernel(ac, config=config) ast = ps.create_kernel(ac, config=config)
stream = ast.compile() stream = ast.compile()
full_communication = dh.synchronization_function(pdf.name, target=dh.default_target, optimization={"openmp": True}) full_communication = dh.synchronization_function(
pdf.name, target=dh.default_target, optimization={"openmp": True}
)
full_communication() full_communication()
dh.run_kernel(stream) dh.run_kernel(stream)
dh.swap("pdf", "pdf_tmp") dh.swap("pdf", "pdf_tmp")
pdf_full_communication = np.copy(dh.cpu_arrays['pdf']) pdf_full_communication = np.copy(dh.cpu_arrays["pdf"])
num = 0 num = 0
for idx, x in np.ndenumerate(dh.cpu_arrays['pdf']): for idx, x in np.ndenumerate(dh.cpu_arrays["pdf"]):
dh.cpu_arrays['pdf'][idx] = num dh.cpu_arrays["pdf"][idx] = num
dh.cpu_arrays['pdf_tmp'][idx] = num dh.cpu_arrays["pdf_tmp"][idx] = num
num += 1 num += 1
optimised_communication = LBMPeriodicityHandling(stencil=stencil, data_handling=dh, pdf_field_name=pdf.name, optimised_communication = LBMPeriodicityHandling(
streaming_pattern='pull') stencil=stencil,
data_handling=dh,
pdf_field_name=pdf.name,
streaming_pattern="pull",
)
optimised_communication() optimised_communication()
dh.run_kernel(stream) dh.run_kernel(stream)
dh.swap("pdf", "pdf_tmp") dh.swap("pdf", "pdf_tmp")
...@@ -119,9 +238,14 @@ def test_optimised_and_full_communication_equivalence(stencil_name): ...@@ -119,9 +238,14 @@ def test_optimised_and_full_communication_equivalence(stencil_name):
for j in range(gl, domain_size[1]): for j in range(gl, domain_size[1]):
for k in range(gl, domain_size[2]): for k in range(gl, domain_size[2]):
for f in range(len(stencil)): for f in range(len(stencil)):
assert dh.cpu_arrays['pdf'][i, j, k, f] == pdf_full_communication[i, j, k, f], print(f) assert (
dh.cpu_arrays["pdf"][i, j, k, f]
== pdf_full_communication[i, j, k, f]
), print(f)
else: else:
for i in range(gl, domain_size[0]): for i in range(gl, domain_size[0]):
for j in range(gl, domain_size[1]): for j in range(gl, domain_size[1]):
for f in range(len(stencil)): for f in range(len(stencil)):
assert dh.cpu_arrays['pdf'][i, j, f] == pdf_full_communication[i, j, f] assert (
dh.cpu_arrays["pdf"][i, j, f] == pdf_full_communication[i, j, f]
)
...@@ -20,7 +20,7 @@ all_results = dict() ...@@ -20,7 +20,7 @@ all_results = dict()
targets = [Target.CPU] targets = [Target.CPU]
try: try:
import pycuda.autoinit import cupy
targets += [Target.GPU] targets += [Target.GPU]
except Exception: except Exception:
pass pass
......
...@@ -21,7 +21,7 @@ all_results = dict() ...@@ -21,7 +21,7 @@ all_results = dict()
targets = [Target.CPU] targets = [Target.CPU]
try: try:
import pycuda.autoinit import cupy
targets += [Target.GPU] targets += [Target.GPU]
except Exception: except Exception:
pass pass
......
{"indexes": {"simulationresult": {"pk": {"key": "pk", "id": "76f7f2660bd4475b90697716ef655a46"}}}, "store_class": "transactional", "index_class": "transactional", "index_store_class": "basic", "serializer_class": "json", "autocommit": false, "version": "0.2.12"}
\ No newline at end of file
...@@ -16,7 +16,7 @@ from pystencils import create_kernel, create_data_handling, Assignment, Target, ...@@ -16,7 +16,7 @@ from pystencils import create_kernel, create_data_handling, Assignment, Target,
from pystencils.slicing import slice_from_direction, get_slice_before_ghost_layer from pystencils.slicing import slice_from_direction, get_slice_before_ghost_layer
def flow_around_sphere(stencil, galilean_correction, L_LU, total_steps): def flow_around_sphere(stencil, galilean_correction, fourth_order_correction, L_LU, total_steps):
if galilean_correction and stencil.Q != 27: if galilean_correction and stencil.Q != 27:
pytest.skip() pytest.skip()
...@@ -38,7 +38,7 @@ def flow_around_sphere(stencil, galilean_correction, L_LU, total_steps): ...@@ -38,7 +38,7 @@ def flow_around_sphere(stencil, galilean_correction, L_LU, total_steps):
lbm_config = LBMConfig(stencil=stencil, method=Method.CUMULANT, relaxation_rate=omega_v, lbm_config = LBMConfig(stencil=stencil, method=Method.CUMULANT, relaxation_rate=omega_v,
compressible=True, compressible=True,
galilean_correction=galilean_correction) galilean_correction=galilean_correction, fourth_order_correction=fourth_order_correction)
lbm_opt = LBMOptimisation(pre_simplification=True) lbm_opt = LBMOptimisation(pre_simplification=True)
config = CreateKernelConfig(target=target) config = CreateKernelConfig(target=target)
...@@ -58,7 +58,13 @@ def flow_around_sphere(stencil, galilean_correction, L_LU, total_steps): ...@@ -58,7 +58,13 @@ def flow_around_sphere(stencil, galilean_correction, L_LU, total_steps):
boundary_assignments) boundary_assignments)
iter_slice = get_slice_before_ghost_layer((1,) + (0,) * (stencil.D - 1)) iter_slice = get_slice_before_ghost_layer((1,) + (0,) * (stencil.D - 1))
extrapolation_ast = create_kernel( extrapolation_ast = create_kernel(
boundary_assignments, config=CreateKernelConfig(iteration_slice=iter_slice, ghost_layers=1, target=target)) boundary_assignments,
config=CreateKernelConfig(
iteration_slice=iter_slice,
ghost_layers=1,
target=target,
skip_independence_check=True)
)
return extrapolation_ast.compile() return extrapolation_ast.compile()
dh = create_data_handling(channel_size, periodicity=False, default_layout='fzyx', default_target=target) dh = create_data_handling(channel_size, periodicity=False, default_layout='fzyx', default_target=target)
...@@ -135,14 +141,22 @@ def flow_around_sphere(stencil, galilean_correction, L_LU, total_steps): ...@@ -135,14 +141,22 @@ def flow_around_sphere(stencil, galilean_correction, L_LU, total_steps):
@pytest.mark.parametrize('stencil', [Stencil.D2Q9, Stencil.D3Q19, Stencil.D3Q27]) @pytest.mark.parametrize('stencil', [Stencil.D2Q9, Stencil.D3Q19, Stencil.D3Q27])
@pytest.mark.parametrize('galilean_correction', [False, True]) @pytest.mark.parametrize('galilean_correction', [False, True])
def test_flow_around_sphere_short(stencil, galilean_correction): @pytest.mark.parametrize('fourth_order_correction', [False, True])
pytest.importorskip('pycuda') def test_flow_around_sphere_short(stencil, galilean_correction, fourth_order_correction):
flow_around_sphere(LBStencil(stencil), galilean_correction, 5, 200) pytest.importorskip('cupy')
stencil = LBStencil(stencil)
if fourth_order_correction and stencil.Q != 27:
pytest.skip("Fourth-order correction only defined for D3Q27 stencil.")
flow_around_sphere(stencil, galilean_correction, fourth_order_correction, 5, 200)
@pytest.mark.parametrize('stencil', [Stencil.D2Q9, Stencil.D3Q19, Stencil.D3Q27]) @pytest.mark.parametrize('stencil', [Stencil.D2Q9, Stencil.D3Q19, Stencil.D3Q27])
@pytest.mark.parametrize('galilean_correction', [False, True]) @pytest.mark.parametrize('galilean_correction', [False, True])
@pytest.mark.parametrize('fourth_order_correction', [False, 0.01])
@pytest.mark.longrun @pytest.mark.longrun
def test_flow_around_sphere_long(stencil, galilean_correction): def test_flow_around_sphere_long(stencil, galilean_correction, fourth_order_correction):
pytest.importorskip('pycuda') pytest.importorskip('cupy')
flow_around_sphere(LBStencil(stencil), galilean_correction, 20, 3000) stencil = LBStencil(stencil)
if fourth_order_correction and stencil.Q != 27:
pytest.skip("Fourth-order correction only defined for D3Q27 stencil.")
flow_around_sphere(stencil, galilean_correction, fourth_order_correction, 20, 3000)
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
from dataclasses import replace
import sympy as sp
import numpy as np
from numpy.testing import assert_allclose
from pystencils.session import *
from pystencils import Target
from lbmpy.session import *
from lbmpy.macroscopic_value_kernels import macroscopic_values_getter, macroscopic_values_setter
```
%% Cell type:code id: tags:
``` python
try:
import cupy
except ImportError:
cupy = None
target = ps.Target.CPU
print('No cupy installed')
if cupy:
target = ps.Target.GPU
```
%% Cell type:markdown id: tags:
## Pipe Flow Scenario
%% Cell type:code id: tags:
``` python
class PeriodicPipeFlow:
def __init__(self, lbm_config, lbm_optimisation, config):
wall_boundary = NoSlip()
self.stencil = lbm_config.stencil
self.streaming_pattern = lbm_config.streaming_pattern
self.target = config.target
self.gpu = self.target == Target.GPU
# Stencil
self.q = stencil.Q
self.dim = stencil.D
# Streaming
self.inplace = is_inplace(self.streaming_pattern)
self.timesteps = get_timesteps(self.streaming_pattern)
self.zeroth_timestep = self.timesteps[0]
# Domain, Data Handling and PDF fields
self.pipe_length = 60
self.pipe_radius = 15
self.domain_size = (self.pipe_length, ) + (2 * self.pipe_radius,) * (self.dim - 1)
self.periodicity = (True, ) + (False, ) * (self.dim - 1)
force = (0.0001, ) + (0.0,) * (self.dim - 1)
self.force = lbm_config.force
self.dh = create_data_handling(domain_size=self.domain_size,
periodicity=self.periodicity, default_target=self.target)
self.pdfs = self.dh.add_array('pdfs', self.q)
if not self.inplace:
self.pdfs_tmp = self.dh.add_array_like('pdfs_tmp', self.pdfs.name)
lbm_optimisation = replace(lbm_optimisation, symbolic_field=self.pdfs)
if not self.inplace:
lbm_optimisation = replace(lbm_optimisation, symbolic_temporary_field=self.pdfs_tmp)
self.lb_collision = create_lb_collision_rule(lbm_config=lbm_config, lbm_optimisation=lbm_optimisation)
self.lb_method = self.lb_collision.method
self.lb_kernels = []
for t in self.timesteps:
lbm_config = replace(lbm_config, timestep=t, collision_rule=self.lb_collision)
self.lb_kernels.append(create_lb_function(lbm_config=lbm_config,
lbm_optimisation=lbm_optimisation,
config=config))
# Macroscopic Values
self.density = 1.0
self.density_field = self.dh.add_array('rho', 1)
u_x = 0.0
self.velocity = (u_x,) * self.dim
self.velocity_field = self.dh.add_array('u', self.dim)
single_gl_config = replace(config, ghost_layers=1)
setter = macroscopic_values_setter(
self.lb_method, self.density, self.velocity, self.pdfs,
streaming_pattern=self.streaming_pattern, previous_timestep=self.zeroth_timestep)
self.init_kernel = create_kernel(setter, config=single_gl_config).compile()
self.getter_kernels = []
for t in self.timesteps:
getter = macroscopic_values_getter(
self.lb_method, self.density_field, self.velocity_field, self.pdfs,
streaming_pattern=self.streaming_pattern, previous_timestep=t)
self.getter_kernels.append(create_kernel(getter, config=single_gl_config).compile())
# Periodicity
self.periodicity_handler = LBMPeriodicityHandling(
self.stencil, self.dh, self.pdfs.name, streaming_pattern=self.streaming_pattern)
# Boundary Handling
self.wall = wall_boundary
self.bh = LatticeBoltzmannBoundaryHandling(
self.lb_method, self.dh, self.pdfs.name,
streaming_pattern=self.streaming_pattern, target=self.target)
self.bh.set_boundary(boundary_obj=self.wall, mask_callback=self.mask_callback)
self.current_timestep = self.zeroth_timestep
def mask_callback(self, x, y, z=None):
y = y - self.pipe_radius
z = z - self.pipe_radius if z is not None else 0
return np.sqrt(y**2 + z**2) >= self.pipe_radius
def init(self):
self.current_timestep = self.zeroth_timestep
self.dh.run_kernel(self.init_kernel)
def step(self):
# Order matters! First communicate, then boundaries, otherwise
# periodicity handling overwrites reflected populations
# Periodicty
self.periodicity_handler(self.current_timestep)
# Boundaries
self.bh(prev_timestep=self.current_timestep)
# Here, the next time step begins
self.current_timestep = self.current_timestep.next()
# LBM Step
self.dh.run_kernel(self.lb_kernels[self.current_timestep.idx])
# Field Swaps
if not self.inplace:
self.dh.swap(self.pdfs.name, self.pdfs_tmp.name)
# Macroscopic Values
self.dh.run_kernel(self.getter_kernels[self.current_timestep.idx])
def run(self, iterations):
for _ in range(iterations):
self.step()
@property
def velocity_array(self):
if self.gpu:
self.dh.to_cpu(self.velocity_field.name)
return self.dh.gather_array(self.velocity_field.name)
def get_trimmed_velocity_array(self):
if self.gpu:
self.dh.to_cpu(self.velocity_field.name)
u = np.copy(self.dh.gather_array(self.velocity_field.name))
mask = self.bh.get_mask(None, self.wall)
for idx in np.ndindex(u.shape[:-1]):
if mask[idx] != 0:
u[idx] = np.full((self.dim, ), np.nan)
return u
```
%% Cell type:markdown id: tags:
## General Setup
%% Cell type:code id: tags:
``` python
stencil = LBStencil(Stencil.D3Q19)
dim = stencil.D
streaming_pattern = 'aa'
config = ps.CreateKernelConfig(target=target)
viscous_rr = 1.1
force = (0.0001, ) + (0.0,) * (dim - 1)
```
%% Cell type:markdown id: tags:
## 1. Reference: SRT Method
%% Cell type:code id: tags:
``` python
srt_config = LBMConfig(stencil=stencil, method=Method.SRT, relaxation_rate=viscous_rr,
force_model=ForceModel.SIMPLE, force=force, streaming_pattern=streaming_pattern)
srt_flow = PeriodicPipeFlow(srt_config, LBMOptimisation(), config)
srt_flow.init()
srt_flow.run(400)
```
%% Output
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[6], line 4
1 srt_config = LBMConfig(stencil=stencil, method=Method.SRT, relaxation_rate=viscous_rr,
2 force_model=ForceModel.SIMPLE, force=force, streaming_pattern=streaming_pattern)
----> 4 srt_flow = PeriodicPipeFlow(srt_config, LBMOptimisation(), config)
5 srt_flow.init()
6 srt_flow.run(400)
Cell In[4], line 47, in PeriodicPipeFlow.__init__(self, lbm_config, lbm_optimisation, config)
45 for t in self.timesteps:
46 lbm_config = replace(lbm_config, timestep=t, collision_rule=self.lb_collision)
---> 47 self.lb_kernels.append(create_lb_function(lbm_config=lbm_config,
48 lbm_optimisation=lbm_optimisation,
49 config=config))
51 # Macroscopic Values
52 self.density = 1.0
File ~/pystencils/lbmpy/lbmpy/creationfunctions.py:505, in create_lb_function(ast, lbm_config, lbm_optimisation, config, optimization, **kwargs)
502 ast = lbm_config.ast
504 if ast is None:
--> 505 ast = create_lb_ast(lbm_config.update_rule, lbm_config=lbm_config,
506 lbm_optimisation=lbm_optimisation, config=config)
508 res = ast.compile()
510 res.method = ast.method
File ~/pystencils/lbmpy/lbmpy/creationfunctions.py:530, in create_lb_ast(update_rule, lbm_config, lbm_optimisation, config, optimization, **kwargs)
525 update_rule = create_lb_update_rule(lbm_config.collision_rule, lbm_config=lbm_config,
526 lbm_optimisation=lbm_optimisation, config=config)
528 field_types = set(fa.field.dtype for fa in update_rule.defined_symbols if isinstance(fa, Field.Access))
--> 530 config = replace(config, data_type=collate_types(field_types), ghost_layers=1)
531 ast = create_kernel(update_rule, config=config)
533 ast.method = update_rule.method
File /usr/lib/python3.11/dataclasses.py:1492, in replace(obj, **changes)
1485 changes[f.name] = getattr(obj, f.name)
1487 # Create the new object, which calls __init__() and
1488 # __post_init__() (if defined), using all of the init fields we've
1489 # added and/or left in 'changes'. If there are values supplied in
1490 # changes that aren't fields, this will correctly raise a
1491 # TypeError.
-> 1492 return obj.__class__(**changes)
File <string>:24, in __init__(self, target, backend, function_name, data_type, default_number_float, default_number_int, iteration_slice, ghost_layers, cpu_openmp, cpu_vectorize_info, cpu_blocking, omp_single_loop, gpu_indexing, gpu_indexing_params, default_assignment_simplifications, cpu_prepend_optimizations, use_auto_for_assignments, index_fields, coordinate_names, allow_double_writes, skip_independence_check)
File ~/pystencils/pystencils/pystencils/config.py:177, in CreateKernelConfig.__post_init__(self)
174 self._check_type(dtype)
176 if not isinstance(self.data_type, dict):
--> 177 dt = copy(self.data_type) # The copy is necessary because BasicType has sympy shinanigans
178 self.data_type = defaultdict(self.DataTypeFactory(dt))
180 if isinstance(self.data_type, dict) and not isinstance(self.data_type, defaultdict):
File /usr/lib/python3.11/copy.py:102, in copy(x)
100 if isinstance(rv, str):
101 return x
--> 102 return _reconstruct(x, None, *rv)
File /usr/lib/python3.11/copy.py:273, in _reconstruct(x, memo, func, args, state, listiter, dictiter, deepcopy)
271 state = deepcopy(state, memo)
272 if hasattr(y, '__setstate__'):
--> 273 y.__setstate__(state)
274 else:
275 if isinstance(state, tuple) and len(state) == 2:
File ~/.local/lib/python3.11/site-packages/sympy/core/basic.py:144, in Basic.__setstate__(self, state)
143 def __setstate__(self, state):
--> 144 for name, value in state.items():
145 setattr(self, name, value)
AttributeError: 'tuple' object has no attribute 'items'
%% Cell type:code id: tags:
``` python
srt_u = srt_flow.get_trimmed_velocity_array()
ps.plot.vector_field_magnitude(srt_u[30,:,:,:])
ps.plot.colorbar()
```
%% Cell type:markdown id: tags:
## 2. Centered Cumulant Method with Central Moment-Space Forcing
%% Cell type:code id: tags:
``` python
cm_config = LBMConfig(method=Method.MONOMIAL_CUMULANT, stencil=stencil, compressible=True,
relaxation_rate=viscous_rr,
force=force, streaming_pattern=streaming_pattern)
lbm_opt = LBMOptimisation(pre_simplification=True)
cm_impl_f_flow = PeriodicPipeFlow(cm_config, lbm_opt, config)
```
%% Cell type:code id: tags:
``` python
lb_method = cm_impl_f_flow.lb_method
assert all(rr == 0 for rr in lb_method.relaxation_rates[1:4])
assert all(rr == viscous_rr for rr in lb_method.relaxation_rates[4:9])
```
%% Cell type:code id: tags:
``` python
cm_impl_f_flow.init()
cm_impl_f_flow.run(400)
```
%% Cell type:code id: tags:
``` python
cm_impl_f_u = cm_impl_f_flow.get_trimmed_velocity_array()
ps.plot.vector_field_magnitude(cm_impl_f_u[30,:,:,:])
ps.plot.colorbar()
```
%% Cell type:code id: tags:
``` python
assert_allclose(cm_impl_f_u, srt_u, rtol=1, atol=1e-4)
```
%% Cell type:markdown id: tags:
## 3. Centered Cumulant Method with Simple force model
%% Cell type:code id: tags:
``` python
from lbmpy.forcemodels import Simple
cm_expl_force_config = LBMConfig(method=Method.CUMULANT, stencil=stencil, compressible=True,
relaxation_rate=viscous_rr, force_model=Simple(force),
force=force, streaming_pattern=streaming_pattern)
lbm_opt = LBMOptimisation(pre_simplification=True)
cm_expl_f_flow = PeriodicPipeFlow(cm_expl_force_config, lbm_opt, config)
```
%% Cell type:code id: tags:
``` python
lb_method = cm_expl_f_flow.lb_method
assert all(rr == 0 for rr in lb_method.relaxation_rates[1:4])
assert all(rr == viscous_rr for rr in lb_method.relaxation_rates[4:9])
```
%% Cell type:code id: tags:
``` python
cm_expl_f_flow.init()
cm_expl_f_flow.run(400)
```
%% Cell type:code id: tags:
``` python
cm_expl_f_u = cm_expl_f_flow.get_trimmed_velocity_array()
ps.plot.vector_field_magnitude(cm_expl_f_u[30,:,:,:])
ps.plot.colorbar()
```
%% Cell type:code id: tags:
``` python
assert_allclose(cm_expl_f_u, srt_u, rtol=1, atol=1e-4)
assert_allclose(cm_expl_f_u, cm_impl_f_u, rtol=1, atol=1e-4)
```