diff --git a/src/pairs/analysis/blocks.py b/src/pairs/analysis/blocks.py index a98a77183132b8df0172b65bba229c4391acfcfd..81714e4eec6eda17cdfc4f451874e1192c63a774 100644 --- a/src/pairs/analysis/blocks.py +++ b/src/pairs/analysis/blocks.py @@ -29,6 +29,18 @@ class SetBlockVariants(Mutator): self.in_assignment = None return ast_node + def mutate_AtomicAdd(self, ast_node): + self.in_assignment = ast_node + ast_node.elem = self.mutate(ast_node.elem) + self.in_assignment = None + ast_node.value = self.mutate(ast_node.value) + + if ast_node.check_for_resize(): + ast_node.resize = self.mutate(ast_node.resize) + ast_node.capacity = self.mutate(ast_node.capacity) + + return ast_node + def mutate_For(self, ast_node): self.push_variant(ast_node.iterator) ast_node.block.add_variant(ast_node.iterator.name()) diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 648acd3c4aaa17f60ad42b790e88dc8448b09309..e7e5fec49692bd3481ad5ed11675000738a2fe19 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -248,6 +248,20 @@ class CGen: index_g = self.generate_expression(prop_access.index) self.print(f"const {tkw} {acc_ref} = {prop_name}[{index_g}];") + if isinstance(ast_node.elem, AtomicAdd): + atomic_add = ast_node.elem + elem = self.generate_expression(atomic_add.elem) + value = self.generate_expression(atomic_add.value) + tkw = Types.c_keyword(atomic_add.type()) + acc_ref = f"atm_add{atomic_add.id()}" + + if atomic_add.check_for_resize(): + resize = self.generate_expression(atomic_add.resize) + capacity = self.generate_expression(atomic_add.capacity) + self.print(f"const {tkw} {acc_ref} = pairs::atomic_add_resize_check(&({elem}), {value}, &({resize}), {capacity});") + else: + self.print(f"const {tkw} {acc_ref} = pairs::atomic_add(&({elem}), {value});") + if isinstance(ast_node, Branch): cond = self.generate_expression(ast_node.cond) self.print(f"if({cond}) {{") @@ -490,36 +504,25 @@ class CGen: return ast_node.name() if isinstance(ast_node, ArrayAccess): - array_name = self.generate_expression(ast_node.array) - acc_index = self.generate_expression(ast_node.flat_index) if mem or ast_node.inlined is True: + array_name = self.generate_expression(ast_node.array) + acc_index = self.generate_expression(ast_node.flat_index) return f"{array_name}[{acc_index}]" return f"a{ast_node.id()}" if isinstance(ast_node, AtomicAdd): - elem = self.generate_expression(ast_node.elem) - value = self.generate_expression(ast_node.value) - if ast_node.check_for_resize(): - resize = self.generate_expression(ast_node.resize) - capacity = self.generate_expression(ast_node.capacity) - return f"pairs::atomic_add_resize_check(&({elem}), {value}, &({resize}), {capacity})" - else: - return f"pairs::atomic_add(&({elem}), {value})" + return f"atm_add{ast_node.id()}" if isinstance(ast_node, BinOp): - lhs = self.generate_expression(ast_node.lhs, mem, index) - rhs = self.generate_expression(ast_node.rhs, index=index) - operator = ast_node.operator() - if ast_node.inlined is True: assert ast_node.type() != Types.Vector, "Vector operations cannot be inlined!" + lhs = self.generate_expression(ast_node.lhs, mem, index) + rhs = self.generate_expression(ast_node.rhs, index=index) + operator = ast_node.operator() return f"({lhs} {operator.symbol()} {rhs})" if ast_node.is_vector_kind(): - if index is None: - print(ast_node) - assert index is not None, "Index must be set for vector reference!" return f"e{ast_node.id()}[{index}]" if ast_node.mem else f"e{ast_node.id()}_{index}" diff --git a/src/pairs/ir/atomic.py b/src/pairs/ir/atomic.py index 11de9bf6aeefe5e82b037c66e8579f5fc8cf03ce..c492b35d844f90436708c6810e12aeb3e8a9704d 100644 --- a/src/pairs/ir/atomic.py +++ b/src/pairs/ir/atomic.py @@ -3,8 +3,15 @@ from pairs.ir.lit import Lit class AtomicAdd(ASTTerm): + last_atomic_add = 0 + + def new_id(): + AtomicAdd.last_atomic_add += 1 + return AtomicAdd.last_atomic_add - 1 + def __init__(self, sim, elem, value): super().__init__(sim) + self.atomic_add_id = AtomicAdd.new_id() self.elem = BinOp.inline(elem) self.value = Lit.cvt(sim, value) self.resize = None @@ -20,6 +27,9 @@ class AtomicAdd(ASTTerm): def check_for_resize(self): return self.resize is not None + def id(self): + return self.atomic_add_id + def type(self): return self.elem.type() diff --git a/src/pairs/transformations/expressions.py b/src/pairs/transformations/expressions.py index 37367637b96e06cee8787ef45a4a768312052fef..07f987b2571ac9fe75edada1cc03084f17e1aa68 100644 --- a/src/pairs/transformations/expressions.py +++ b/src/pairs/transformations/expressions.py @@ -127,6 +127,16 @@ class AddExpressionDeclarations(Mutator): return ast_node + def mutate_AtomicAdd(self, ast_node): + ast_node.elem = self.mutate(ast_node.elem) + ast_node.value = self.mutate(ast_node.value) + atomic_add_id = id(ast_node) + if atomic_add_id not in self.declared_exprs and atomic_add_id not in self.params: + self.push_decl(Decl(ast_node.sim, ast_node)) + self.declared_exprs.append(atomic_add_id) + + return ast_node + def mutate_Block(self, ast_node): block_id = id(ast_node) self.decls[block_id] = []