diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py index 85e615c5ff9a1e1d23e5548a3350a11bce1abbc7..05081631d02825a11da2a417b64cb586d543f6aa 100644 --- a/src/pairs/sim/comm.py +++ b/src/pairs/sim/comm.py @@ -63,9 +63,9 @@ class Comm: @pairs_host_block def reverse_comm(self, reduce=False): self.sim.module_name(f"reverse_comm") - self.prop_list = self.sim.properties.reduction_props() + prop_list = self.sim.properties.reduction_props() - if self.prop_list : + if prop_list : for step in range(self.dom_part.number_of_steps() - 1, -1, -1): if self.sim._target.is_gpu(): CopyArray(self.sim, self.nsend, Contexts.Host, Actions.ReadOnly) @@ -84,9 +84,9 @@ class Comm: Assign(self.sim, self.send_offsets_reverse[j], self.recv_offsets[j]) Assign(self.sim, self.recv_offsets_reverse[j], self.send_offsets[j]) - PackGhostParticlesReverse(self, step, self.prop_list) - CommunicateDataReverse(self, step, self.prop_list) - UnpackGhostParticlesReverse(self, step, self.prop_list, reduce) + PackGhostParticlesReverse(self, step, prop_list) + CommunicateDataReverse(self, step, prop_list) + UnpackGhostParticlesReverse(self, step, prop_list, reduce) @pairs_inline def borders(self): @@ -486,7 +486,7 @@ class UnpackGhostParticlesReverse(Lowerable): nelems = Types.number_of_elements(self.sim, p.type()) for e in range(nelems): if self.reduce: - Assign(self.sim, p[m][e], p[m][e] + recv_buffer_reverse[i][p_offset + e]) + AtomicInc(self.sim, p[m][e], recv_buffer_reverse[i][p_offset + e]) else: Assign(self.sim, p[m][e], recv_buffer_reverse[i][p_offset + e]) @@ -495,7 +495,7 @@ class UnpackGhostParticlesReverse(Lowerable): else: cast_fn = lambda x: Cast(self.sim, x, p.type()) if p.type() != Types.Real else x if self.reduce: - Assign(self.sim, p[m], p[m] + cast_fn(recv_buffer_reverse[i][p_offset])) + AtomicInc(self.sim, p[m], cast_fn(recv_buffer_reverse[i][p_offset])) else: Assign(self.sim, p[m], cast_fn(recv_buffer_reverse[i][p_offset])) p_offset += 1 diff --git a/src/pairs/transformations/expressions.py b/src/pairs/transformations/expressions.py index 5565dcc9891fd39a2d2e714fa563b2b4092dbc07..1e1ab5e49abb98d489a2d0a8ed8d54424c2795c5 100644 --- a/src/pairs/transformations/expressions.py +++ b/src/pairs/transformations/expressions.py @@ -193,6 +193,18 @@ class AddExpressionDeclarations(Mutator): self.declared_exprs.append(atomic_add_id) return ast_node + + def mutate_AtomicInc(self, ast_node): + self.writing = True + ast_node.elem = self.mutate(ast_node.elem) + self.writing = False + ast_node.value = self.mutate(ast_node.value) + atomic_inc_id = id(ast_node) + if atomic_inc_id not in self.declared_exprs and atomic_inc_id not in self.params: + self.push_decl(Decl(ast_node.sim, ast_node)) + self.declared_exprs.append(atomic_inc_id) + + return ast_node def mutate_Block(self, ast_node): block_id = id(ast_node)