Skip to content
Snippets Groups Projects
Commit 516e464c authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Add separate declarations for AtomicAdd

parent 1c6392dc
No related branches found
No related tags found
No related merge requests found
...@@ -29,6 +29,18 @@ class SetBlockVariants(Mutator): ...@@ -29,6 +29,18 @@ class SetBlockVariants(Mutator):
self.in_assignment = None self.in_assignment = None
return ast_node return ast_node
def mutate_AtomicAdd(self, ast_node):
self.in_assignment = ast_node
ast_node.elem = self.mutate(ast_node.elem)
self.in_assignment = None
ast_node.value = self.mutate(ast_node.value)
if ast_node.check_for_resize():
ast_node.resize = self.mutate(ast_node.resize)
ast_node.capacity = self.mutate(ast_node.capacity)
return ast_node
def mutate_For(self, ast_node): def mutate_For(self, ast_node):
self.push_variant(ast_node.iterator) self.push_variant(ast_node.iterator)
ast_node.block.add_variant(ast_node.iterator.name()) ast_node.block.add_variant(ast_node.iterator.name())
......
...@@ -248,6 +248,20 @@ class CGen: ...@@ -248,6 +248,20 @@ class CGen:
index_g = self.generate_expression(prop_access.index) index_g = self.generate_expression(prop_access.index)
self.print(f"const {tkw} {acc_ref} = {prop_name}[{index_g}];") self.print(f"const {tkw} {acc_ref} = {prop_name}[{index_g}];")
if isinstance(ast_node.elem, AtomicAdd):
atomic_add = ast_node.elem
elem = self.generate_expression(atomic_add.elem)
value = self.generate_expression(atomic_add.value)
tkw = Types.c_keyword(atomic_add.type())
acc_ref = f"atm_add{atomic_add.id()}"
if atomic_add.check_for_resize():
resize = self.generate_expression(atomic_add.resize)
capacity = self.generate_expression(atomic_add.capacity)
self.print(f"const {tkw} {acc_ref} = pairs::atomic_add_resize_check(&({elem}), {value}, &({resize}), {capacity});")
else:
self.print(f"const {tkw} {acc_ref} = pairs::atomic_add(&({elem}), {value});")
if isinstance(ast_node, Branch): if isinstance(ast_node, Branch):
cond = self.generate_expression(ast_node.cond) cond = self.generate_expression(ast_node.cond)
self.print(f"if({cond}) {{") self.print(f"if({cond}) {{")
...@@ -490,36 +504,25 @@ class CGen: ...@@ -490,36 +504,25 @@ class CGen:
return ast_node.name() return ast_node.name()
if isinstance(ast_node, ArrayAccess): if isinstance(ast_node, ArrayAccess):
array_name = self.generate_expression(ast_node.array)
acc_index = self.generate_expression(ast_node.flat_index)
if mem or ast_node.inlined is True: if mem or ast_node.inlined is True:
array_name = self.generate_expression(ast_node.array)
acc_index = self.generate_expression(ast_node.flat_index)
return f"{array_name}[{acc_index}]" return f"{array_name}[{acc_index}]"
return f"a{ast_node.id()}" return f"a{ast_node.id()}"
if isinstance(ast_node, AtomicAdd): if isinstance(ast_node, AtomicAdd):
elem = self.generate_expression(ast_node.elem) return f"atm_add{ast_node.id()}"
value = self.generate_expression(ast_node.value)
if ast_node.check_for_resize():
resize = self.generate_expression(ast_node.resize)
capacity = self.generate_expression(ast_node.capacity)
return f"pairs::atomic_add_resize_check(&({elem}), {value}, &({resize}), {capacity})"
else:
return f"pairs::atomic_add(&({elem}), {value})"
if isinstance(ast_node, BinOp): if isinstance(ast_node, BinOp):
lhs = self.generate_expression(ast_node.lhs, mem, index)
rhs = self.generate_expression(ast_node.rhs, index=index)
operator = ast_node.operator()
if ast_node.inlined is True: if ast_node.inlined is True:
assert ast_node.type() != Types.Vector, "Vector operations cannot be inlined!" assert ast_node.type() != Types.Vector, "Vector operations cannot be inlined!"
lhs = self.generate_expression(ast_node.lhs, mem, index)
rhs = self.generate_expression(ast_node.rhs, index=index)
operator = ast_node.operator()
return f"({lhs} {operator.symbol()} {rhs})" return f"({lhs} {operator.symbol()} {rhs})"
if ast_node.is_vector_kind(): if ast_node.is_vector_kind():
if index is None:
print(ast_node)
assert index is not None, "Index must be set for vector reference!" assert index is not None, "Index must be set for vector reference!"
return f"e{ast_node.id()}[{index}]" if ast_node.mem else f"e{ast_node.id()}_{index}" return f"e{ast_node.id()}[{index}]" if ast_node.mem else f"e{ast_node.id()}_{index}"
......
...@@ -3,8 +3,15 @@ from pairs.ir.lit import Lit ...@@ -3,8 +3,15 @@ from pairs.ir.lit import Lit
class AtomicAdd(ASTTerm): class AtomicAdd(ASTTerm):
last_atomic_add = 0
def new_id():
AtomicAdd.last_atomic_add += 1
return AtomicAdd.last_atomic_add - 1
def __init__(self, sim, elem, value): def __init__(self, sim, elem, value):
super().__init__(sim) super().__init__(sim)
self.atomic_add_id = AtomicAdd.new_id()
self.elem = BinOp.inline(elem) self.elem = BinOp.inline(elem)
self.value = Lit.cvt(sim, value) self.value = Lit.cvt(sim, value)
self.resize = None self.resize = None
...@@ -20,6 +27,9 @@ class AtomicAdd(ASTTerm): ...@@ -20,6 +27,9 @@ class AtomicAdd(ASTTerm):
def check_for_resize(self): def check_for_resize(self):
return self.resize is not None return self.resize is not None
def id(self):
return self.atomic_add_id
def type(self): def type(self):
return self.elem.type() return self.elem.type()
......
...@@ -127,6 +127,16 @@ class AddExpressionDeclarations(Mutator): ...@@ -127,6 +127,16 @@ class AddExpressionDeclarations(Mutator):
return ast_node return ast_node
def mutate_AtomicAdd(self, ast_node):
ast_node.elem = self.mutate(ast_node.elem)
ast_node.value = self.mutate(ast_node.value)
atomic_add_id = id(ast_node)
if atomic_add_id not in self.declared_exprs and atomic_add_id not in self.params:
self.push_decl(Decl(ast_node.sim, ast_node))
self.declared_exprs.append(atomic_add_id)
return ast_node
def mutate_Block(self, ast_node): def mutate_Block(self, ast_node):
block_id = id(ast_node) block_id = id(ast_node)
self.decls[block_id] = [] self.decls[block_id] = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment