diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py
index 593651567b4e62d89349245528bc96dc7237dc7c..61734d1a914022471b473ef61afe4c735ead08bd 100644
--- a/src/pairs/ir/arrays.py
+++ b/src/pairs/ir/arrays.py
@@ -54,6 +54,7 @@ class Array(ASTNode):
         self.arr_layout = a_layout
         self.arr_sync = a_sync
         self.arr_ndims = len(self.arr_sizes)
+        self.arr_strides = {}
         self.static = False
         self.device_flag = False
         Array.last_array_id += 1
@@ -88,6 +89,12 @@ class Array(ASTNode):
     def is_static(self):
         return self.static
 
+    def set_stride(self, dim, stride):  # Record an override for dimension `dim`; strides() returns it instead of arr_sizes[dim].
+        self.arr_strides[dim] = stride  # A later call with the same `dim` replaces the earlier value.
+
+    def strides(self):  # Per-dimension strides: the set_stride() override when present, otherwise the dimension's size.
+        return [self.arr_strides.get(i, self.arr_sizes[i]) for i in range(self.arr_ndims)]
+
     def alloc_size(self):
         return reduce((lambda x, y: x * y), [s for s in self.arr_sizes])
 
@@ -152,17 +159,18 @@ class ArrayAccess(ASTTerm):
     def check_and_set_flat_index(self):
         if len(self.partial_indexes) == self.array.ndims():
             sizes = self.array.sizes()
+            strides = self.array.strides()
             layout = self.array.layout()
 
             if layout == Layouts.AoS:
                 for s in range(0, len(sizes)):
                     self.flat_index = (self.partial_indexes[s] if self.flat_index is None
-                                       else self.flat_index * sizes[s] + self.partial_indexes[s])
+                                       else self.flat_index * strides[s] + self.partial_indexes[s])
 
             elif layout == Layouts.SoA:
                 for s in reversed(range(0, len(sizes))):
                     self.flat_index = (self.partial_indexes[s] if self.flat_index is None
-                                       else self.flat_index * sizes[s] + self.partial_indexes[s])
+                                       else self.flat_index * strides[s] + self.partial_indexes[s])
 
             else:
                 raise Exception("Invalid data layout!")
diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index 8ce93a477acd74ada38c8d58d20ec31a081cd467..8fd6c49373d85aef638527742bae4c5e4627b73c 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -162,9 +162,9 @@ class PackGhostParticles(Lowerable):
     @pairs_device_block
     def lower(self):
         send_buffer = self.comm.send_buffer
+        send_buffer.set_stride(1, self.get_elems_per_particle())
         send_map = self.comm.send_map
         send_mult = self.comm.send_mult
-        elems_per_particle = self.get_elems_per_particle()
         self.sim.module_name(f"pack_ghost_particles{self.step}_" + "_".join([str(p.id()) for p in self.prop_list]))
 
         step_indexes = self.comm.dom_part.step_indexes(self.step)
@@ -172,7 +172,6 @@ class PackGhostParticles(Lowerable):
         for i in For(self.sim, start, start + sum([self.comm.nsend[j] for j in step_indexes])):
             p_offset = 0
             m = send_map[i]
-            buffer_index = i * elems_per_particle
             for p in self.prop_list:
                 if p.type() == Types.Vector:
                     for d in range(self.sim.ndims()):
@@ -180,13 +179,13 @@ class PackGhostParticles(Lowerable):
                         if p == self.sim.position():
                             src += send_mult[i][d] * self.sim.grid.length(d)
 
-                        send_buffer[buffer_index][p_offset + d].set(src)
+                        send_buffer[i][p_offset + d].set(src)
 
                     p_offset += self.sim.ndims()
 
                 else:
                     cast_fn = lambda x: Cast(self.sim, x, Types.Double) if p.type() != Types.Double else x
-                    send_buffer[buffer_index][p_offset].set(cast_fn(p[m]))
+                    send_buffer[i][p_offset].set(cast_fn(p[m]))
                     p_offset += 1
 
             
@@ -205,24 +204,23 @@ class UnpackGhostParticles(Lowerable):
     def lower(self):
         nlocal = self.sim.nlocal
         recv_buffer = self.comm.recv_buffer
-        elems_per_particle = self.get_elems_per_particle()
+        recv_buffer.set_stride(1, self.get_elems_per_particle())
         self.sim.module_name(f"unpack_ghost_particles{self.step}_" + "_".join([str(p.id()) for p in self.prop_list]))
 
         step_indexes = self.comm.dom_part.step_indexes(self.step)
         start = self.comm.recv_offsets[step_indexes[0]]
         for i in For(self.sim, start, start + sum([self.comm.nrecv[j] for j in step_indexes])):
             p_offset = 0
-            buffer_index = i * elems_per_particle
             for p in self.prop_list:
                 if p.type() == Types.Vector:
                     for d in range(self.sim.ndims()):
-                        p[nlocal + i][d].set(recv_buffer[buffer_index][p_offset + d])
-                        
+                        p[nlocal + i][d].set(recv_buffer[i][p_offset + d])
+
                     p_offset += self.sim.ndims()
 
                 else:
                     cast_fn = lambda x: Cast(self.sim, x, p.type()) if p.type() != Types.Double else x
-                    p[nlocal + i].set(cast_fn(recv_buffer[buffer_index][p_offset]))
+                    p[nlocal + i].set(cast_fn(recv_buffer[i][p_offset]))
                     p_offset += 1