From b9cb302e4262a3875ab40489d86fa17ae90de83e Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Sun, 9 Apr 2023 10:07:54 +0200 Subject: [PATCH] Regression comm --- lbmpy/advanced_streaming/communication.py | 25 ++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/lbmpy/advanced_streaming/communication.py b/lbmpy/advanced_streaming/communication.py index 0073b807..786c6009 100644 --- a/lbmpy/advanced_streaming/communication.py +++ b/lbmpy/advanced_streaming/communication.py @@ -1,7 +1,8 @@ import itertools from pystencils import CreateKernelConfig, Field, Assignment, AssignmentCollection from pystencils.slicing import shift_slice, get_slice_before_ghost_layer, normalize_slice -from lbmpy.advanced_streaming.utility import is_inplace, get_accessor, numeric_index, Timestep, get_timesteps +from lbmpy.advanced_streaming.utility import is_inplace, get_accessor, numeric_index,\ + Timestep, get_timesteps, numeric_offsets from pystencils.datahandling import SerialDataHandling from pystencils.enums import Target from itertools import chain @@ -132,6 +133,13 @@ def get_communication_slices( origin_slice = get_slice_before_ghost_layer(comm_dir, ghost_layers=ghost_layers, thickness=1) src_slice = _fix_length_one_slices(origin_slice) + write_offsets = numeric_offsets(write_accesses[d]) + tangential_dir = tuple(s - c for s, c in zip(streaming_dir, comm_dir)) + + # TODO: this is just a hotfix. _trim_slice_in_direction breaks FreeSlip BC with adjacent periodic side + if streaming_pattern != "pull": + src_slice = shift_slice(_trim_slice_in_direction(src_slice, tangential_dir), write_offsets) + neighbour_transform = _get_neighbour_transform(comm_dir, ghost_layers) dst_slice = shift_slice(src_slice, neighbour_transform) @@ -210,3 +218,18 @@ def _fix_length_one_slices(slices): return slices else: return tuple(_fix_length_one_slices(s) for s in slices) + + +def _trim_slice_in_direction(slices, direction): + assert len(slices) == len(direction) + + result = [] + for s, d in zip(slices, direction): + if isinstance(s, int): + result.append(s) + continue + start = s.start + 1 if d == -1 else s.start + stop = s.stop - 1 if d == 1 else s.stop + result.append(slice(start, stop, s.step)) + + return tuple(result) -- GitLab