Skip to content
Snippets Groups Projects
Commit 4a4649aa authored by Rafael Ravedutti Lucio Machado's avatar Rafael Ravedutti Lucio Machado
Browse files

Fix LICM transformation

parent 5dd75387
No related branches found
No related tags found
No related merge requests found
......@@ -52,7 +52,9 @@ class BinOp(ASTNode):
self.bin_op_def = BinOpDef(self)
def __str__(self):
return f"BinOp<a: {self.lhs.id()}, b: {self.rhs.id()}, op: {self.op}>"
a = self.lhs.id() if isinstance(self.lhs, BinOp) else self.lhs
b = self.rhs.id() if isinstance(self.rhs, BinOp) else self.rhs
return f"BinOp<a: {a}, b: {b}, op: {self.op}>"
def match(self, bin_op):
return self.lhs == bin_op.lhs and \
......
......@@ -21,6 +21,9 @@ class Iter(ASTTerm):
def id(self):
return self.iter_id
def name(self):
return f"i{self.iter_id}"
def type(self):
return Type_Int
......
from ast.bin_op import BinOp
from ast.loops import For, While
from ast.mutator import Mutator
from ast.visitor import Visitor
......@@ -7,6 +8,20 @@ class SetBlockVariants(Mutator):
def __init__(self, ast):
super().__init__(ast)
self.in_assignment = None
self.blocks = []
def push_variant(self, ast_node):
if self.in_assignment is not None:
for block in self.blocks:
block.add_variant(ast_node.name())
return ast_node
def mutate_Block(self, ast_node):
self.blocks.append(ast_node)
ast_node.stmts = [self.mutate(s) for s in ast_node.stmts]
self.blocks.pop()
return ast_node
def mutate_Assign(self, ast_node):
self.in_assignment = ast_node if ast_node.parent_block is not None else None
......@@ -16,34 +31,23 @@ class SetBlockVariants(Mutator):
return ast_node
def mutate_For(self, ast_node):
ast_node.block.add_variant(id(ast_node.iterator))
self.push_variant(ast_node.iterator)
ast_node.block.add_variant(ast_node.iterator.name())
ast_node.iterator = self.mutate(ast_node.iterator)
ast_node.block = self.mutate(ast_node.block)
return ast_node
def mutate_Array(self, ast_node):
if self.in_assignment is not None:
self.in_assignment.parent_block.add_variant(id(ast_node))
return ast_node
return self.push_variant(ast_node)
def mutate_Iter(self, ast_node):
if self.in_assignment is not None:
self.in_assignment.parent_block.add_variant(id(ast_node))
return ast_node
return self.push_variant(ast_node)
def mutate_Property(self, ast_node):
if self.in_assignment is not None:
self.in_assignment.parent_block.add_variant(id(ast_node))
return self.push_variant(ast_node)
return ast_node
def mutate_Variable(self, ast_node):
if self.in_assignment is not None:
self.in_assignment.parent_block.add_variant(id(ast_node))
return ast_node
def mutate_Var(self, ast_node):
return self.push_variant(ast_node)
class SetParentBlock(Visitor):
......@@ -54,7 +58,7 @@ class SetParentBlock(Visitor):
def current_block(self):
return self.blocks[-1] if self.blocks else None
def visit_Assign(self, ast_node):
def set_parent_block(self, ast_node):
ast_node.parent_block = self.current_block()
self.visit_children(ast_node)
......@@ -64,37 +68,32 @@ class SetParentBlock(Visitor):
self.visit_children(ast_node)
self.blocks.pop()
def visit_Assign(self, ast_node):
self.set_parent_block(ast_node)
def visit_BinOpDef(self, ast_node):
ast_node.parent_block = self.current_block()
self.visit_children(ast_node)
self.set_parent_block(ast_node)
def visit_Branch(self, ast_node):
ast_node.parent_block = self.current_block()
self.visit_children(ast_node)
self.set_parent_block(ast_node)
def visit_Filter(self, ast_node):
ast_node.parent_block = self.current_block()
self.visit_children(ast_node)
self.set_parent_block(ast_node)
def visit_For(self, ast_node):
ast_node.parent_block = self.current_block()
self.visit_children(ast_node)
self.set_parent_block(ast_node)
def visit_ParticleFor(self, ast_node):
ast_node.parent_block = self.current_block()
self.visit_children(ast_node)
self.set_parent_block(ast_node)
def visit_Malloc(self, ast_node):
ast_node.parent_block = self.current_block()
self.visit_children(ast_node)
self.set_parent_block(ast_node)
def visit_Realloc(self, ast_node):
ast_node.parent_block = self.current_block()
self.visit_children(ast_node)
self.set_parent_block(ast_node)
def visit_While(self, ast_node):
ast_node.parent_block = self.current_block()
self.visit_children(ast_node)
self.set_parent_block(ast_node)
def get_loop_parent_block(self, ast_node):
assert isinstance(ast_node, (For, While)), "Node must be a loop!"
......@@ -107,26 +106,26 @@ class SetBinOpTerminals(Visitor):
super().__init__(ast)
self.bin_ops = []
def push_terminal(self, ast_node):
for bin_op in self.bin_ops:
bin_op.add_terminal(ast_node.name())
def visit_BinOp(self, ast_node):
self.bin_ops.append(ast_node)
self.visit_children(ast_node)
self.bin_ops.pop()
def visit_Array(self, ast_node):
for bin_op in self.bin_ops:
bin_op.add_terminal(id(ast_node))
self.push_terminal(ast_node)
def visit_Iter(self, ast_node):
for bin_op in self.bin_ops:
bin_op.add_terminal(id(ast_node))
self.push_terminal(ast_node)
def visit_Property(self, ast_node):
for bin_op in self.bin_ops:
bin_op.add_terminal(id(ast_node))
self.push_terminal(ast_node)
def visit_Variable(self, ast_node):
for bin_op in self.bin_ops:
bin_op.add_terminal(id(ast_node))
def visit_Var(self, ast_node):
self.push_terminal(ast_node)
class LICM(Mutator):
......@@ -143,20 +142,28 @@ class LICM(Mutator):
self.loops.pop()
return ast_node
def mutate_While(self, ast_node):
self.lifts[id(ast_node)] = []
self.loops.append(ast_node)
ast_node.cond = self.mutate(ast_node.cond)
ast_node.block = self.mutate(ast_node.block)
self.loops.pop()
return ast_node
def mutate_BinOpDef(self, ast_node):
if self.loops:
if self.loops and isinstance(ast_node.bin_op, BinOp):
last_loop = self.loops[-1]
print(f"Checking lifting for {ast_node.id()}")
if not last_loop.block.variants.intersect(ast_node.bin_op.terminals):
#print(f"variants = {last_loop.block.variants}, terminals = {ast_node.bin_op.terminals}")
if not last_loop.block.variants.intersection(ast_node.bin_op.terminals):
print(f'lifting {ast_node.bin_op.id()}')
self.lifts[id(last_loop)].append(ast_node)
print(f"Lifting {ast_node.id()}")
return None
return ast_node
def mutate_Block(self, ast_node):
new_stmts = []
stmts = self.mutate(ast_node.stmts)
stmts = [self.mutate(s) for s in ast_node.stmts]
for s in stmts:
if s is not None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment