diff --git a/src/lbmpy/boundaries/boundaryhandling.py b/src/lbmpy/boundaries/boundaryhandling.py index cf8f4239a66affab7639bbe4c32bd1c1802b64aa..9e1e710493228c9caf763a160c1a91072c5865d2 100644 --- a/src/lbmpy/boundaries/boundaryhandling.py +++ b/src/lbmpy/boundaries/boundaryhandling.py @@ -10,6 +10,8 @@ from pystencils.stencil import inverse_direction from lbmpy.advanced_streaming.indexing import BetweenTimestepsIndexing from lbmpy.advanced_streaming.utility import is_inplace, Timestep, AccessPdfValues +from .._compat import IS_PYSTENCILS_2 + class LatticeBoltzmannBoundaryHandling(BoundaryHandling): """ @@ -19,13 +21,16 @@ class LatticeBoltzmannBoundaryHandling(BoundaryHandling): """ def __init__(self, lb_method, data_handling, pdf_field_name, streaming_pattern='pull', - name="boundary_handling", flag_interface=None, target=Target.CPU, openmp=False): + name="boundary_handling", flag_interface=None, target=Target.CPU, openmp=False, **kwargs): self._lb_method = lb_method self._streaming_pattern = streaming_pattern self._inplace = is_inplace(streaming_pattern) self._prev_timestep = None - super(LatticeBoltzmannBoundaryHandling, self).__init__(data_handling, pdf_field_name, lb_method.stencil, - name, flag_interface, target, openmp) + super(LatticeBoltzmannBoundaryHandling, self).__init__( + data_handling, pdf_field_name, lb_method.stencil, + name, flag_interface, target=target, openmp=openmp, + **kwargs + ) # ------------------------- Overridden methods of pystencils.BoundaryHandling ------------------------- @@ -66,10 +71,15 @@ class LatticeBoltzmannBoundaryHandling(BoundaryHandling): return self._boundary_object_to_boundary_info[boundary_obj].flag def _create_boundary_kernel(self, symbolic_field, symbolic_index_field, boundary_obj, prev_timestep=Timestep.BOTH): + if IS_PYSTENCILS_2: + additional_args = {"default_dtype": self._default_dtype} + else: + additional_args = dict() + return create_lattice_boltzmann_boundary_kernel( symbolic_field, symbolic_index_field, self._lb_method, boundary_obj, prev_timestep=prev_timestep, streaming_pattern=self._streaming_pattern, - target=self._target, cpu_openmp=self._openmp) + target=self._target, cpu_openmp=self._openmp, **additional_args) class InplaceStreamingBoundaryInfo(object): diff --git a/src/lbmpy/lbstep.py b/src/lbmpy/lbstep.py index 53300d7b1b74a6f13a9876d28d7da99696895bfb..949b97a3265d8fb0c1fc48726a0c46beb19e8f47 100644 --- a/src/lbmpy/lbstep.py +++ b/src/lbmpy/lbstep.py @@ -31,7 +31,9 @@ class LatticeBoltzmannStep: velocity_input_array_name=None, time_step_order='stream_collide', flag_interface=None, alignment_if_vectorized=64, fixed_loop_sizes=True, timeloop_creation_function=TimeLoop, - lbm_config=None, lbm_optimisation=None, config=None, **method_parameters): + lbm_config=None, lbm_optimisation=None, + config: CreateKernelConfig | None = None, + **method_parameters): if optimization is None: optimization = {} @@ -43,7 +45,10 @@ class LatticeBoltzmannStep: raise ValueError("When passing a data_handling, the domain_size parameter can not be specified") if config is not None: - target = config.target + if IS_PYSTENCILS_2: + target = config.get_target() + else: + target = config.target else: target = optimization.get('target', Target.CPU) @@ -166,10 +171,14 @@ class LatticeBoltzmannStep: self._sync_tmp = data_handling.synchronization_function([self._tmp_arr_name], stencil_name, target, stencil_restricted=True) - self._boundary_handling = LatticeBoltzmannBoundaryHandling(self.method, self._data_handling, self._pdf_arr_name, - name=name + "_boundary_handling", - flag_interface=flag_interface, - target=target, openmp=config.cpu_openmp) + self._boundary_handling = LatticeBoltzmannBoundaryHandling( + self.method, self._data_handling, self._pdf_arr_name, + name=name + "_boundary_handling", + flag_interface=flag_interface, + target=target, + openmp=config.cpu_openmp, + **({"default_dtype": field_dtype} if IS_PYSTENCILS_2 else dict()) + ) self._lbm_config = lbm_config self._lbm_optimisation = lbm_optimisation diff --git a/src/lbmpy/scenarios.py b/src/lbmpy/scenarios.py index 1bc87b5d1816dd734c3eb44400de64e7ce226d21..28bf1da6cb62377a1314753511a3866be141bdd3 100644 --- a/src/lbmpy/scenarios.py +++ b/src/lbmpy/scenarios.py @@ -31,6 +31,7 @@ from lbmpy._compat import IS_PYSTENCILS_2 from lbmpy.boundaries import UBB, FixedDensity, NoSlip from lbmpy.geometry import add_pipe_inflow_boundary, add_pipe_walls from lbmpy.lbstep import LatticeBoltzmannStep +from pystencils import Target from pystencils.datahandling import create_data_handling from pystencils.slicing import slice_from_direction @@ -87,7 +88,7 @@ def create_lid_driven_cavity(domain_size=None, lid_velocity=0.005, lbm_kernel=No assert domain_size is not None or data_handling is not None if data_handling is None: optimization = kwargs.get('optimization', None) - target = optimization.get('target', None) if optimization else None + target = optimization.get('target', None) if optimization else Target.CPU data_handling = create_data_handling(domain_size, periodicity=False, default_ghost_layers=1, diff --git a/tests/test_float_kernel.py b/tests/test_float_kernel.py index aac3cdb870aa73f55eed19edca0680de03031d29..20a645188b51fcd57245ddfce68d653e3aedce66 100644 --- a/tests/test_float_kernel.py +++ b/tests/test_float_kernel.py @@ -37,11 +37,11 @@ def test_creation(double_precision, method_enum): assert "float" in code -@pytest.mark.parametrize("double_precision", [False, True]) +@pytest.mark.parametrize("numeric_type", ["float32", "float64"]) @pytest.mark.parametrize( "method_enum", [Method.SRT, Method.CENTRAL_MOMENT, Method.CUMULANT] ) -def test_scenario(method_enum, double_precision): +def test_scenario(method_enum, numeric_type): lbm_config = LBMConfig( stencil=LBStencil(Stencil.D3Q27), method=method_enum, @@ -51,18 +51,18 @@ def test_scenario(method_enum, double_precision): if IS_PYSTENCILS_2: config = ps.CreateKernelConfig( - default_dtype="float64" if double_precision else "float32" + default_dtype=numeric_type ) else: config = ps.CreateKernelConfig( - data_type="float64" if double_precision else "float32", - default_number_float="float64" if double_precision else "float32", + data_type=numeric_type, + default_number_float=numeric_type ) sc = create_lid_driven_cavity((16, 16, 8), lbm_config=lbm_config, config=config) sc.run(1) code = ps.get_code_str(sc.ast) - if double_precision: + if numeric_type == "float64": assert "float" not in code assert "double" in code else: