Skip to content
Snippets Groups Projects
Commit 468aef7a authored by Markus Holzer's avatar Markus Holzer
Browse files

regression

parent 86d900b1
No related branches found
No related tags found
1 merge request!141Regression comm
import itertools import itertools
from pystencils import CreateKernelConfig, Field, Assignment, AssignmentCollection from pystencils import CreateKernelConfig, Field, Assignment, AssignmentCollection
from pystencils.slicing import shift_slice, get_slice_before_ghost_layer, normalize_slice from pystencils.slicing import shift_slice, get_slice_before_ghost_layer, normalize_slice
from lbmpy.advanced_streaming.utility import is_inplace, get_accessor, numeric_index, Timestep, get_timesteps from lbmpy.advanced_streaming.utility import is_inplace, get_accessor, numeric_index, Timestep, get_timesteps, numeric_offsets
from pystencils.datahandling import SerialDataHandling from pystencils.datahandling import SerialDataHandling
from pystencils.enums import Target from pystencils.enums import Target
from itertools import chain from itertools import chain
...@@ -132,6 +132,10 @@ def get_communication_slices( ...@@ -132,6 +132,10 @@ def get_communication_slices(
origin_slice = get_slice_before_ghost_layer(comm_dir, ghost_layers=ghost_layers, thickness=1) origin_slice = get_slice_before_ghost_layer(comm_dir, ghost_layers=ghost_layers, thickness=1)
src_slice = _fix_length_one_slices(origin_slice) src_slice = _fix_length_one_slices(origin_slice)
write_offsets = numeric_offsets(write_accesses[d])
tangential_dir = tuple(s - c for s, c in zip(streaming_dir, comm_dir))
src_slice = shift_slice(_trim_slice_in_direction(src_slice, tangential_dir), write_offsets)
neighbour_transform = _get_neighbour_transform(comm_dir, ghost_layers) neighbour_transform = _get_neighbour_transform(comm_dir, ghost_layers)
dst_slice = shift_slice(src_slice, neighbour_transform) dst_slice = shift_slice(src_slice, neighbour_transform)
...@@ -210,3 +214,19 @@ def _fix_length_one_slices(slices): ...@@ -210,3 +214,19 @@ def _fix_length_one_slices(slices):
return slices return slices
else: else:
return tuple(_fix_length_one_slices(s) for s in slices) return tuple(_fix_length_one_slices(s) for s in slices)
def _trim_slice_in_direction(slices, direction):
assert len(slices) == len(direction)
result = []
for s, d in zip(slices, direction):
if isinstance(s, int):
result.append(s)
continue
start = s.start + 1 if d == -1 else s.start
stop = s.stop - 1 if d == 1 else s.stop
result.append(slice(start, stop, s.step))
return tuple(result)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment