Skip to content
Snippets Groups Projects
Commit 988825e3 authored by Frederik Hennig
Browse files

Add test cases for slice-copy kernels. Some code style cleanup.

parent 019c73c6
No related branches found
No related tags found
1 merge request: !182 "Fix GPU copy kernels in periodicity handling"
Pipeline #69868 passed
...@@ -186,7 +186,6 @@ def get_communication_slices( ...@@ -186,7 +186,6 @@ def get_communication_slices(
def periodic_pdf_gpu_copy_kernel(pdf_field, src_slice, dst_slice, domain_size=None): def periodic_pdf_gpu_copy_kernel(pdf_field, src_slice, dst_slice, domain_size=None):
"""Generate a GPU kernel which copies all values from one slice of a field """Generate a GPU kernel which copies all values from one slice of a field
to another non-overlapping slice.""" to another non-overlapping slice."""
# from pystencils.gpu.kernelcreation import create_cuda_kernel
from pystencils import create_kernel from pystencils import create_kernel
pdf_idx = src_slice[-1] pdf_idx = src_slice[-1]
...@@ -206,7 +205,7 @@ def periodic_pdf_gpu_copy_kernel(pdf_field, src_slice, dst_slice, domain_size=No ...@@ -206,7 +205,7 @@ def periodic_pdf_gpu_copy_kernel(pdf_field, src_slice, dst_slice, domain_size=No
return s.start if isinstance(s, slice) else s return s.start if isinstance(s, slice) else s
def _stop(s): def _stop(s):
return s.stop if isinstance(s, slice) else s + 1 return s.stop if isinstance(s, slice) else s
offset = [ offset = [
_start(s1) - _start(s2) _start(s1) - _start(s2)
...@@ -223,7 +222,9 @@ def periodic_pdf_gpu_copy_kernel(pdf_field, src_slice, dst_slice, domain_size=No ...@@ -223,7 +222,9 @@ def periodic_pdf_gpu_copy_kernel(pdf_field, src_slice, dst_slice, domain_size=No
] ]
) )
config = CreateKernelConfig( config = CreateKernelConfig(
iteration_slice=dst_slice, skip_independence_check=True, target=Target.GPU iteration_slice=dst_slice,
skip_independence_check=True,
target=Target.GPU,
) )
ast = create_kernel(copy_eq, config=config) ast = create_kernel(copy_eq, config=config)
......
...@@ -172,7 +172,7 @@ class PdfsToMomentsByMatrixTransform(AbstractRawMomentTransform): ...@@ -172,7 +172,7 @@ class PdfsToMomentsByMatrixTransform(AbstractRawMomentTransform):
# ----------------------------- Private Members ----------------------------- # ----------------------------- Private Members -----------------------------
@ property @property
def _default_simplification(self): def _default_simplification(self):
forward_simp = SimplificationStrategy() forward_simp = SimplificationStrategy()
# forward_simp.add(substitute_moments_in_conserved_quantity_equations) # forward_simp.add(substitute_moments_in_conserved_quantity_equations)
...@@ -218,7 +218,7 @@ class PdfsToMomentsByChimeraTransform(AbstractRawMomentTransform): ...@@ -218,7 +218,7 @@ class PdfsToMomentsByChimeraTransform(AbstractRawMomentTransform):
self.moment_polynomials) self.moment_polynomials)
self.poly_to_mono_matrix = self.mono_to_poly_matrix.inv() self.poly_to_mono_matrix = self.mono_to_poly_matrix.inv()
@ property @property
def absorbs_conserved_quantity_equations(self): def absorbs_conserved_quantity_equations(self):
return True return True
...@@ -414,7 +414,7 @@ class PdfsToMomentsByChimeraTransform(AbstractRawMomentTransform): ...@@ -414,7 +414,7 @@ class PdfsToMomentsByChimeraTransform(AbstractRawMomentTransform):
# ----------------------------- Private Members ----------------------------- # ----------------------------- Private Members -----------------------------
@ property @property
def _default_simplification(self): def _default_simplification(self):
from lbmpy.methods.momentbased.momentbasedsimplifications import ( from lbmpy.methods.momentbased.momentbasedsimplifications import (
substitute_moments_in_conserved_quantity_equations, substitute_moments_in_conserved_quantity_equations,
......
...@@ -9,6 +9,7 @@ from lbmpy.advanced_streaming.communication import ( ...@@ -9,6 +9,7 @@ from lbmpy.advanced_streaming.communication import (
get_communication_slices, get_communication_slices,
_fix_length_one_slices, _fix_length_one_slices,
LBMPeriodicityHandling, LBMPeriodicityHandling,
periodic_pdf_gpu_copy_kernel,
) )
from lbmpy.advanced_streaming.utility import streaming_patterns, Timestep from lbmpy.advanced_streaming.utility import streaming_patterns, Timestep
from lbmpy.enums import Stencil from lbmpy.enums import Stencil
...@@ -87,6 +88,47 @@ def test_pull_communication_slices(stencil): ...@@ -87,6 +88,47 @@ def test_pull_communication_slices(stencil):
assert dst == gl_slice assert dst == gl_slice
@pytest.mark.parametrize("direction", LBStencil(Stencil.D3Q27).stencil_entries)
@pytest.mark.parametrize("pull", [False, True])
def test_gpu_comm_kernels(direction: tuple, pull: bool):
    """Verify that the GPU slice-copy kernel moves a marker value from source to destination.

    For each D3Q27 stencil entry, a source and a destination slice are derived
    from the ghost region of the inverse direction and the slice before the
    ghost layer of ``direction``. With ``pull=True`` the ghost region is the
    destination; otherwise it is the source. The source slice is filled with
    42.0 on the host, the kernel built by ``periodic_pdf_gpu_copy_kernel`` is
    run, and the destination slice is expected to hold 42.0 afterwards.

    The test is skipped when ``cupy`` cannot be imported.
    """
    pytest.importorskip("cupy")

    lb_stencil = LBStencil(Stencil.D3Q27)
    dim = lb_stencil.D
    opposite_dir = lb_stencil[lb_stencil.inverse_index(direction)]

    dh: ps.datahandling.SerialDataHandling = ps.create_data_handling(
        (4,) * dim,
        periodicity=(True,) * dim,
        parallel=False,
        default_target=ps.Target.GPU,
    )

    field = dh.add_array("field", values_per_cell=2)
    array_name = field.name

    # The two regions swap roles depending on the copy direction.
    if not pull:
        dst_slice = get_slice_before_ghost_layer(direction)
        src_slice = get_ghost_region_slice(opposite_dir)
    else:
        dst_slice = get_ghost_region_slice(opposite_dir)
        src_slice = get_slice_before_ghost_layer(direction)

    # Append a trailing index of 1 to both slices; the kernel factory reads
    # the last entry of the source slice as the PDF index.
    src_slice += (1,)
    dst_slice += (1,)

    copy_kernel = periodic_pdf_gpu_copy_kernel(field, src_slice, dst_slice)

    # Place the marker on the host, upload, run the copy, and download again.
    dh.cpu_arrays[array_name][src_slice] = 42.0
    dh.all_to_gpu()
    dh.run_kernel(copy_kernel)
    dh.all_to_cpu()

    copied = dh.cpu_arrays[array_name][dst_slice]
    np.testing.assert_equal(copied, 42.0)
@pytest.mark.parametrize("stencil", [Stencil.D2Q9, Stencil.D3Q19]) @pytest.mark.parametrize("stencil", [Stencil.D2Q9, Stencil.D3Q19])
@pytest.mark.parametrize("streaming_pattern", streaming_patterns) @pytest.mark.parametrize("streaming_pattern", streaming_patterns)
def test_direct_copy_and_kernels_equivalence(stencil: Stencil, streaming_pattern: str): def test_direct_copy_and_kernels_equivalence(stencil: Stencil, streaming_pattern: str):
...@@ -106,6 +148,9 @@ def test_direct_copy_and_kernels_equivalence(stencil: Stencil, streaming_pattern ...@@ -106,6 +148,9 @@ def test_direct_copy_and_kernels_equivalence(stencil: Stencil, streaming_pattern
pdfs_a = dh.add_array("pdfs_a", values_per_cell=stencil.Q) pdfs_a = dh.add_array("pdfs_a", values_per_cell=stencil.Q)
pdfs_b = dh.add_array("pdfs_b", values_per_cell=stencil.Q) pdfs_b = dh.add_array("pdfs_b", values_per_cell=stencil.Q)
dh.fill(pdfs_a.name, 0.0, ghost_layers=True)
dh.fill(pdfs_b.name, 0.0, ghost_layers=True)
for q in range(stencil.Q): for q in range(stencil.Q):
sl = ps.make_slice[:4, :4, q] if stencil.D == 2 else ps.make_slice[:4, :4, :4, q] sl = ps.make_slice[:4, :4, q] if stencil.D == 2 else ps.make_slice[:4, :4, :4, q]
dh.cpu_arrays[pdfs_a.name][sl] = q dh.cpu_arrays[pdfs_a.name][sl] = q
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment