Skip to content
Snippets Groups Projects
Commit 5586d82d authored by Frederik Hennig's avatar Frederik Hennig
Browse files

fix data type settings in lb scenarios and boundary handling

parent e1d24200
No related branches found
No related tags found
1 merge request!172Changes for compatibility with pystencils 2.0
......@@ -10,6 +10,8 @@ from pystencils.stencil import inverse_direction
from lbmpy.advanced_streaming.indexing import BetweenTimestepsIndexing
from lbmpy.advanced_streaming.utility import is_inplace, Timestep, AccessPdfValues
from .._compat import IS_PYSTENCILS_2
class LatticeBoltzmannBoundaryHandling(BoundaryHandling):
"""
......@@ -19,13 +21,16 @@ class LatticeBoltzmannBoundaryHandling(BoundaryHandling):
"""
def __init__(self, lb_method, data_handling, pdf_field_name, streaming_pattern='pull',
name="boundary_handling", flag_interface=None, target=Target.CPU, openmp=False):
name="boundary_handling", flag_interface=None, target=Target.CPU, openmp=False, **kwargs):
self._lb_method = lb_method
self._streaming_pattern = streaming_pattern
self._inplace = is_inplace(streaming_pattern)
self._prev_timestep = None
super(LatticeBoltzmannBoundaryHandling, self).__init__(data_handling, pdf_field_name, lb_method.stencil,
name, flag_interface, target, openmp)
super(LatticeBoltzmannBoundaryHandling, self).__init__(
data_handling, pdf_field_name, lb_method.stencil,
name, flag_interface, target=target, openmp=openmp,
**kwargs
)
# ------------------------- Overridden methods of pystencils.BoundaryHandling -------------------------
......@@ -66,10 +71,15 @@ class LatticeBoltzmannBoundaryHandling(BoundaryHandling):
return self._boundary_object_to_boundary_info[boundary_obj].flag
def _create_boundary_kernel(self, symbolic_field, symbolic_index_field, boundary_obj, prev_timestep=Timestep.BOTH):
if IS_PYSTENCILS_2:
additional_args = {"default_dtype": self._default_dtype}
else:
additional_args = dict()
return create_lattice_boltzmann_boundary_kernel(
symbolic_field, symbolic_index_field, self._lb_method, boundary_obj,
prev_timestep=prev_timestep, streaming_pattern=self._streaming_pattern,
target=self._target, cpu_openmp=self._openmp)
target=self._target, cpu_openmp=self._openmp, **additional_args)
class InplaceStreamingBoundaryInfo(object):
......
......@@ -31,7 +31,9 @@ class LatticeBoltzmannStep:
velocity_input_array_name=None, time_step_order='stream_collide', flag_interface=None,
alignment_if_vectorized=64, fixed_loop_sizes=True,
timeloop_creation_function=TimeLoop,
lbm_config=None, lbm_optimisation=None, config=None, **method_parameters):
lbm_config=None, lbm_optimisation=None,
config: CreateKernelConfig | None = None,
**method_parameters):
if optimization is None:
optimization = {}
......@@ -43,7 +45,10 @@ class LatticeBoltzmannStep:
raise ValueError("When passing a data_handling, the domain_size parameter can not be specified")
if config is not None:
target = config.target
if IS_PYSTENCILS_2:
target = config.get_target()
else:
target = config.target
else:
target = optimization.get('target', Target.CPU)
......@@ -166,10 +171,14 @@ class LatticeBoltzmannStep:
self._sync_tmp = data_handling.synchronization_function([self._tmp_arr_name], stencil_name, target,
stencil_restricted=True)
self._boundary_handling = LatticeBoltzmannBoundaryHandling(self.method, self._data_handling, self._pdf_arr_name,
name=name + "_boundary_handling",
flag_interface=flag_interface,
target=target, openmp=config.cpu_openmp)
self._boundary_handling = LatticeBoltzmannBoundaryHandling(
self.method, self._data_handling, self._pdf_arr_name,
name=name + "_boundary_handling",
flag_interface=flag_interface,
target=target,
openmp=config.cpu_openmp,
**({"default_dtype": field_dtype} if IS_PYSTENCILS_2 else dict())
)
self._lbm_config = lbm_config
self._lbm_optimisation = lbm_optimisation
......
......@@ -31,6 +31,7 @@ from lbmpy._compat import IS_PYSTENCILS_2
from lbmpy.boundaries import UBB, FixedDensity, NoSlip
from lbmpy.geometry import add_pipe_inflow_boundary, add_pipe_walls
from lbmpy.lbstep import LatticeBoltzmannStep
from pystencils import Target
from pystencils.datahandling import create_data_handling
from pystencils.slicing import slice_from_direction
......@@ -87,7 +88,7 @@ def create_lid_driven_cavity(domain_size=None, lid_velocity=0.005, lbm_kernel=No
assert domain_size is not None or data_handling is not None
if data_handling is None:
optimization = kwargs.get('optimization', None)
target = optimization.get('target', None) if optimization else None
target = optimization.get('target', None) if optimization else Target.CPU
data_handling = create_data_handling(domain_size,
periodicity=False,
default_ghost_layers=1,
......
......@@ -37,11 +37,11 @@ def test_creation(double_precision, method_enum):
assert "float" in code
@pytest.mark.parametrize("double_precision", [False, True])
@pytest.mark.parametrize("numeric_type", ["float32", "float64"])
@pytest.mark.parametrize(
"method_enum", [Method.SRT, Method.CENTRAL_MOMENT, Method.CUMULANT]
)
def test_scenario(method_enum, double_precision):
def test_scenario(method_enum, numeric_type):
lbm_config = LBMConfig(
stencil=LBStencil(Stencil.D3Q27),
method=method_enum,
......@@ -51,18 +51,18 @@ def test_scenario(method_enum, double_precision):
if IS_PYSTENCILS_2:
config = ps.CreateKernelConfig(
default_dtype="float64" if double_precision else "float32"
default_dtype=numeric_type
)
else:
config = ps.CreateKernelConfig(
data_type="float64" if double_precision else "float32",
default_number_float="float64" if double_precision else "float32",
data_type=numeric_type,
default_number_float=numeric_type
)
sc = create_lid_driven_cavity((16, 16, 8), lbm_config=lbm_config, config=config)
sc.run(1)
code = ps.get_code_str(sc.ast)
if double_precision:
if numeric_type == "float64":
assert "float" not in code
assert "double" in code
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment