diff --git a/ast/bin_op.py b/ast/bin_op.py index 6d48c6d72f2e5d03a4c82640c968ec523679c011..993c3ba39237e25aa4759c8bb6d70e594e5429dc 100644 --- a/ast/bin_op.py +++ b/ast/bin_op.py @@ -52,7 +52,9 @@ class BinOp(ASTNode): self.bin_op_def = BinOpDef(self) def __str__(self): - return f"BinOp<a: {self.lhs.id()}, b: {self.rhs.id()}, op: {self.op}>" + a = self.lhs.id() if isinstance(self.lhs, BinOp) else self.lhs + b = self.rhs.id() if isinstance(self.rhs, BinOp) else self.rhs + return f"BinOp<a: {a}, b: {b}, op: {self.op}>" def match(self, bin_op): return self.lhs == bin_op.lhs and \ diff --git a/ast/loops.py b/ast/loops.py index 2dc06ee35db21c2ce2f683c1f974197ea10dc8d8..82d85ad534f7f26a50d3f794562721081a064e85 100644 --- a/ast/loops.py +++ b/ast/loops.py @@ -21,6 +21,9 @@ class Iter(ASTTerm): def id(self): return self.iter_id + def name(self): + return f"i{self.iter_id}" + def type(self): return Type_Int diff --git a/transformations/LICM.py b/transformations/LICM.py index 4742f19adf852e065cb7ba65733b5ce837c2dbf3..e6e86fd1cdf689c04b6c09e710f9e45e7c684996 100644 --- a/transformations/LICM.py +++ b/transformations/LICM.py @@ -1,3 +1,4 @@ +from ast.bin_op import BinOp from ast.loops import For, While from ast.mutator import Mutator from ast.visitor import Visitor @@ -7,6 +8,20 @@ class SetBlockVariants(Mutator): def __init__(self, ast): super().__init__(ast) self.in_assignment = None + self.blocks = [] + + def push_variant(self, ast_node): + if self.in_assignment is not None: + for block in self.blocks: + block.add_variant(ast_node.name()) + + return ast_node + + def mutate_Block(self, ast_node): + self.blocks.append(ast_node) + ast_node.stmts = [self.mutate(s) for s in ast_node.stmts] + self.blocks.pop() + return ast_node def mutate_Assign(self, ast_node): self.in_assignment = ast_node if ast_node.parent_block is not None else None @@ -16,34 +31,23 @@ class SetBlockVariants(Mutator): return ast_node def mutate_For(self, ast_node): - ast_node.block.add_variant(id(ast_node.iterator)) + self.push_variant(ast_node.iterator) + ast_node.block.add_variant(ast_node.iterator.name()) ast_node.iterator = self.mutate(ast_node.iterator) ast_node.block = self.mutate(ast_node.block) return ast_node def mutate_Array(self, ast_node): - if self.in_assignment is not None: - self.in_assignment.parent_block.add_variant(id(ast_node)) - - return ast_node + return self.push_variant(ast_node) def mutate_Iter(self, ast_node): - if self.in_assignment is not None: - self.in_assignment.parent_block.add_variant(id(ast_node)) - - return ast_node + return self.push_variant(ast_node) def mutate_Property(self, ast_node): - if self.in_assignment is not None: - self.in_assignment.parent_block.add_variant(id(ast_node)) + return self.push_variant(ast_node) - return ast_node - - def mutate_Variable(self, ast_node): - if self.in_assignment is not None: - self.in_assignment.parent_block.add_variant(id(ast_node)) - - return ast_node + def mutate_Var(self, ast_node): + return self.push_variant(ast_node) class SetParentBlock(Visitor): @@ -54,7 +58,7 @@ class SetParentBlock(Visitor): def current_block(self): return self.blocks[-1] if self.blocks else None - def visit_Assign(self, ast_node): + def set_parent_block(self, ast_node): ast_node.parent_block = self.current_block() self.visit_children(ast_node) @@ -64,37 +68,32 @@ class SetParentBlock(Visitor): self.visit_children(ast_node) self.blocks.pop() + def visit_Assign(self, ast_node): + self.set_parent_block(ast_node) + def visit_BinOpDef(self, ast_node): - ast_node.parent_block = self.current_block() - self.visit_children(ast_node) + self.set_parent_block(ast_node) def visit_Branch(self, ast_node): - ast_node.parent_block = self.current_block() - self.visit_children(ast_node) + self.set_parent_block(ast_node) def visit_Filter(self, ast_node): - ast_node.parent_block = self.current_block() - self.visit_children(ast_node) + self.set_parent_block(ast_node) def visit_For(self, ast_node): - ast_node.parent_block = self.current_block() - self.visit_children(ast_node) + self.set_parent_block(ast_node) def visit_ParticleFor(self, ast_node): - ast_node.parent_block = self.current_block() - self.visit_children(ast_node) + self.set_parent_block(ast_node) def visit_Malloc(self, ast_node): - ast_node.parent_block = self.current_block() - self.visit_children(ast_node) + self.set_parent_block(ast_node) def visit_Realloc(self, ast_node): - ast_node.parent_block = self.current_block() - self.visit_children(ast_node) + self.set_parent_block(ast_node) def visit_While(self, ast_node): - ast_node.parent_block = self.current_block() - self.visit_children(ast_node) + self.set_parent_block(ast_node) def get_loop_parent_block(self, ast_node): assert isinstance(ast_node, (For, While)), "Node must be a loop!" @@ -107,26 +106,26 @@ class SetBinOpTerminals(Visitor): super().__init__(ast) self.bin_ops = [] + def push_terminal(self, ast_node): + for bin_op in self.bin_ops: + bin_op.add_terminal(ast_node.name()) + def visit_BinOp(self, ast_node): self.bin_ops.append(ast_node) self.visit_children(ast_node) self.bin_ops.pop() def visit_Array(self, ast_node): - for bin_op in self.bin_ops: - bin_op.add_terminal(id(ast_node)) + self.push_terminal(ast_node) def visit_Iter(self, ast_node): - for bin_op in self.bin_ops: - bin_op.add_terminal(id(ast_node)) + self.push_terminal(ast_node) def visit_Property(self, ast_node): - for bin_op in self.bin_ops: - bin_op.add_terminal(id(ast_node)) + self.push_terminal(ast_node) - def visit_Variable(self, ast_node): - for bin_op in self.bin_ops: - bin_op.add_terminal(id(ast_node)) + def visit_Var(self, ast_node): + self.push_terminal(ast_node) class LICM(Mutator): @@ -143,20 +142,28 @@ class LICM(Mutator): self.loops.pop() return ast_node + def mutate_While(self, ast_node): + self.lifts[id(ast_node)] = [] + self.loops.append(ast_node) + ast_node.cond = self.mutate(ast_node.cond) + ast_node.block = self.mutate(ast_node.block) + self.loops.pop() + return ast_node + def mutate_BinOpDef(self, ast_node): - if self.loops: + if self.loops and isinstance(ast_node.bin_op, BinOp): last_loop = self.loops[-1] - print(f"Checking lifting for {ast_node.id()}") - if not last_loop.block.variants.intersect(ast_node.bin_op.terminals): + #print(f"variants = {last_loop.block.variants}, terminals = {ast_node.bin_op.terminals}") + if not last_loop.block.variants.intersection(ast_node.bin_op.terminals): + print(f'lifting {ast_node.bin_op.id()}') self.lifts[id(last_loop)].append(ast_node) - print(f"Lifting {ast_node.id()}") return None return ast_node def mutate_Block(self, ast_node): new_stmts = [] - stmts = self.mutate(ast_node.stmts) + stmts = [self.mutate(s) for s in ast_node.stmts] for s in stmts: if s is not None: