diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py index 593651567b4e62d89349245528bc96dc7237dc7c..61734d1a914022471b473ef61afe4c735ead08bd 100644 --- a/src/pairs/ir/arrays.py +++ b/src/pairs/ir/arrays.py @@ -54,6 +54,7 @@ class Array(ASTNode): self.arr_layout = a_layout self.arr_sync = a_sync self.arr_ndims = len(self.arr_sizes) + self.arr_strides = {} self.static = False self.device_flag = False Array.last_array_id += 1 @@ -88,6 +89,12 @@ class Array(ASTNode): def is_static(self): return self.static + def set_stride(self, dim, stride): + self.arr_strides[dim] = stride + + def strides(self): + return [self.arr_strides[i] if i in self.arr_strides else self.arr_sizes[i] for i in range(self.arr_ndims)] + def alloc_size(self): return reduce((lambda x, y: x * y), [s for s in self.arr_sizes]) @@ -152,17 +159,18 @@ class ArrayAccess(ASTTerm): def check_and_set_flat_index(self): if len(self.partial_indexes) == self.array.ndims(): sizes = self.array.sizes() + strides = self.array.strides() layout = self.array.layout() if layout == Layouts.AoS: for s in range(0, len(sizes)): self.flat_index = (self.partial_indexes[s] if self.flat_index is None - else self.flat_index * sizes[s] + self.partial_indexes[s]) + else self.flat_index * strides[s] + self.partial_indexes[s]) elif layout == Layouts.SoA: for s in reversed(range(0, len(sizes))): self.flat_index = (self.partial_indexes[s] if self.flat_index is None - else self.flat_index * sizes[s] + self.partial_indexes[s]) + else self.flat_index * strides[s] + self.partial_indexes[s]) else: raise Exception("Invalid data layout!") diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py index 8ce93a477acd74ada38c8d58d20ec31a081cd467..8fd6c49373d85aef638527742bae4c5e4627b73c 100644 --- a/src/pairs/sim/comm.py +++ b/src/pairs/sim/comm.py @@ -162,9 +162,9 @@ class PackGhostParticles(Lowerable): @pairs_device_block def lower(self): send_buffer = self.comm.send_buffer + send_buffer.set_stride(1, self.get_elems_per_particle()) send_map = self.comm.send_map send_mult = self.comm.send_mult - elems_per_particle = self.get_elems_per_particle() self.sim.module_name(f"pack_ghost_particles{self.step}_" + "_".join([str(p.id()) for p in self.prop_list])) step_indexes = self.comm.dom_part.step_indexes(self.step) @@ -172,7 +172,6 @@ class PackGhostParticles(Lowerable): for i in For(self.sim, start, start + sum([self.comm.nsend[j] for j in step_indexes])): p_offset = 0 m = send_map[i] - buffer_index = i * elems_per_particle for p in self.prop_list: if p.type() == Types.Vector: for d in range(self.sim.ndims()): @@ -180,13 +179,13 @@ class PackGhostParticles(Lowerable): if p == self.sim.position(): src += send_mult[i][d] * self.sim.grid.length(d) - send_buffer[buffer_index][p_offset + d].set(src) + send_buffer[i][p_offset + d].set(src) p_offset += self.sim.ndims() else: cast_fn = lambda x: Cast(self.sim, x, Types.Double) if p.type() != Types.Double else x - send_buffer[buffer_index][p_offset].set(cast_fn(p[m])) + send_buffer[i][p_offset].set(cast_fn(p[m])) p_offset += 1 @@ -205,24 +204,23 @@ class UnpackGhostParticles(Lowerable): def lower(self): nlocal = self.sim.nlocal recv_buffer = self.comm.recv_buffer - elems_per_particle = self.get_elems_per_particle() + recv_buffer.set_stride(1, self.get_elems_per_particle()) self.sim.module_name(f"unpack_ghost_particles{self.step}_" + "_".join([str(p.id()) for p in self.prop_list])) step_indexes = self.comm.dom_part.step_indexes(self.step) start = self.comm.recv_offsets[step_indexes[0]] for i in For(self.sim, start, start + sum([self.comm.nrecv[j] for j in step_indexes])): p_offset = 0 - buffer_index = i * elems_per_particle for p in self.prop_list: if p.type() == Types.Vector: for d in range(self.sim.ndims()): - p[nlocal + i][d].set(recv_buffer[buffer_index][p_offset + d]) - + p[nlocal + i][d].set(recv_buffer[i][p_offset + d]) + p_offset += self.sim.ndims() else: cast_fn = lambda x: Cast(self.sim, x, p.type()) if p.type() != Types.Double else x - p[nlocal + i].set(cast_fn(recv_buffer[buffer_index][p_offset])) + p[nlocal + i].set(cast_fn(recv_buffer[i][p_offset])) p_offset += 1