diff --git a/ast/assign.py b/ast/assign.py index a6cc919c8e528b16f6160549c2461ef9c05fa08d..08f6a7200bc61d523a4e05375718342cd33a9130 100644 --- a/ast/assign.py +++ b/ast/assign.py @@ -7,7 +7,6 @@ from functools import reduce class Assign(ASTNode): def __init__(self, sim, dest, src): super().__init__(sim) - self.parent_block = None self.type = dest.type() src = as_lit_ast(sim, src) @@ -16,18 +15,19 @@ class Assign(ASTNode): for i in range(0, sim.dimensions): from ast.bin_op import BinOp - dsrc = (src if (not isinstance(src, BinOp) or - src.type() != Type_Vector) - else src[i]) - - self.assignments.append((dest[i], dsrc)) + dim_src = src if not isinstance(src, BinOp) or src.type() != Type_Vector else src[i] + self.assignments.append((dest[i], dim_src)) else: self.assignments = [(dest, src)] def __str__(self): return f"Assign<{self.assignments}>" + def destinations(self): + return [a[0] for a in self.assignments] + + def sources(self): + return [a[1] for a in self.assignments] + def children(self): - return reduce((lambda x, y: x + y), [ - [self.assignments[i][0], self.assignments[i][1]] - for i in range(0, len(self.assignments))]) + return reduce((lambda x, y: x + y), [[a[0], a[1]] for a in self.assignments]) diff --git a/ast/ast_node.py b/ast/ast_node.py index e58afc26693907e38fd322ff5969d7a37271e675..8d37b6961bd96c0268fc92581d90aec561c3832b 100644 --- a/ast/ast_node.py +++ b/ast/ast_node.py @@ -4,6 +4,7 @@ from ast.data_types import Type_Invalid class ASTNode: def __init__(self, sim): self.sim = sim + self._parent_block = None # Set during SetParentBlock transformation def __str__(self): return "ASTNode<>" @@ -14,5 +15,9 @@ class ASTNode: def scope(self): return self.sim.global_scope + @property + def parent_block(self): + return self._parent_block + def children(self): return [] diff --git a/ast/block.py b/ast/block.py index e85facec636e02ee94235d762cef1f85f355db07..85ddd1fe13f644a47caf0e92c803232e4e714a9a 100644 --- a/ast/block.py +++ b/ast/block.py @@ -5,7 +5,7 @@ class Block(ASTNode): def __init__(self, sim, stmts): super().__init__(sim) self.level = 0 - self.expressions = [] + self.variants = [] if isinstance(stmts, Block): self.stmts = stmts.statements() @@ -29,24 +29,19 @@ class Block(ASTNode): def add_statement(self, stmt): if isinstance(stmt, list): - for s in stmt: - s.parent_block = self - self.stmts = self.stmts + stmt - else: - stmt.parent_block = self self.stmts.append(stmt) + def add_variant(self, variant): + if isinstance(variant, list): + self.variants = self.variants + variant + else: + self.variants.append(variant) + def statements(self): return self.stmts - def add_expression(self, expr): - if isinstance(expr, list): - self.expressions = self.expressions + expr - else: - self.expressions.append(expr) - def children(self): return self.stmts diff --git a/ast/branches.py b/ast/branches.py index a6302609d42b7ed6b25513a2ce66bdbc5717e027..209dd12b67fcb43c15015d628b682a40d6c8a9fc 100644 --- a/ast/branches.py +++ b/ast/branches.py @@ -6,7 +6,6 @@ from ast.lit import as_lit_ast class Branch(ASTNode): def __init__(self, sim, cond, one_way=False, blk_if=None, blk_else=None): self.sim = sim - self.parent_block = None self.cond = as_lit_ast(sim, cond) self.switch = True self.block_if = Block(sim, []) if blk_if is None else blk_if diff --git a/ast/loops.py b/ast/loops.py index e0c645cfe256a281b509a14f1d4df8487da911b1..2dc06ee35db21c2ce2f683c1f974197ea10dc8d8 100644 --- a/ast/loops.py +++ b/ast/loops.py @@ -46,7 +46,6 @@ class For(ASTNode): self.iterator = Iter(sim, self) self.min = as_lit_ast(sim, range_min) self.max = as_lit_ast(sim, range_max) - self.parent_block = None self.block = Block(sim, []) if block is None else block def __str__(self): @@ -80,7 +79,6 @@ class ParticleFor(For): class While(ASTNode): def __init__(self, sim, cond, block=None): super().__init__(sim) - self.parent_block = None self.cond = BinOp.inline(cond) self.block = Block(sim, []) if block is None else block @@ -103,7 +101,6 @@ class While(ASTNode): class NeighborFor(): def __init__(self, sim, particle, cell_lists): self.sim = sim - self.parent_block = None self.particle = particle self.cell_lists = cell_lists diff --git a/ast/memory.py b/ast/memory.py index 3458bf9fa464a1a4cde36b555b85324fff40b4b1..e853a0dd77b6df1827b6eeaf2dd0034409d0263d 100644 --- a/ast/memory.py +++ b/ast/memory.py @@ -8,7 +8,6 @@ import operator class Malloc(ASTNode): def __init__(self, sim, array, sizes, decl=False): super().__init__(sim) - self.parent_block = None self.array = array self.decl = decl self.prim_size = Sizeof(sim, array.type()) @@ -22,7 +21,6 @@ class Malloc(ASTNode): class Realloc(ASTNode): def __init__(self, sim, array, size): super().__init__(sim) - self.parent_block = None self.array = array self.prim_size = Sizeof(sim, array.type()) self.size = BinOp.inline(self.prim_size * size) diff --git a/ast/mutator.py b/ast/mutator.py index f450773eb5081d27f4fa8f5f1df8546024802323..6ec8727c7713586dfc4653c3b9db9890b47b29e1 100644 --- a/ast/mutator.py +++ b/ast/mutator.py @@ -1,66 +1,19 @@ -from ast.arrays import ArrayAccess -from ast.assign import Assign -from ast.bin_op import BinOp, BinOpDef -from ast.block import Block -from ast.branches import Branch -from ast.cast import Cast -from ast.loops import For, While -from ast.math import Sqrt -from ast.memory import Malloc, Realloc -from ast.select import Select -from sim.timestep import Timestep - - class Mutator: def __init__(self, ast, max_depth=0): self.ast = ast self.max_depth = 0 + def get_method(self, method_name): + method = getattr(self, method_name, None) + return method if callable(method) else None + def mutate(self, ast_node=None): if ast_node is None: ast_node = self.ast - if isinstance(ast_node, ArrayAccess): - return self.mutate_ArrayAccess(ast_node) - - elif isinstance(ast_node, Assign): - return self.mutate_Assign(ast_node) - - elif isinstance(ast_node, BinOp): - return self.mutate_BinOp(ast_node) - - elif isinstance(ast_node, BinOpDef): - return self.mutate_BinOpDef(ast_node) - - elif isinstance(ast_node, Block): - return self.mutate_Block(ast_node) - - elif isinstance(ast_node, Branch): - return self.mutate_Branch(ast_node) - - elif isinstance(ast_node, Cast): - return self.mutate_Cast(ast_node) - - elif isinstance(ast_node, For): - return self.mutate_For(ast_node) - - elif isinstance(ast_node, Malloc): - return self.mutate_Malloc(ast_node) - - elif isinstance(ast_node, Realloc): - return self.mutate_Realloc(ast_node) - - elif isinstance(ast_node, Select): - return self.mutate_Select(ast_node) - - elif isinstance(ast_node, Sqrt): - return self.mutate_Sqrt(ast_node) - - elif isinstance(ast_node, Timestep): - return self.mutate_Timestep(ast_node) - - elif isinstance(ast_node, While): - return self.mutate_While(ast_node) + method = self.get_method(f"mutate_{type(ast_node).__name__}") + if method is not None: + return method(ast_node) return ast_node @@ -74,10 +27,7 @@ class Mutator: return ast_node def mutate_Assign(self, ast_node): - ast_node.assignments = [ - (self.mutate(ast_node.assignments[i][0]), self.mutate(ast_node.assignments[i][1])) - for i in range(0, len(ast_node.assignments)) - ] + ast_node.assignments = [(self.mutate(a[0]), self.mutate(a[1])) for a in ast_node.assignments] return ast_node def mutate_BinOp(self, ast_node): @@ -100,6 +50,9 @@ class Mutator: ast_node.block_else = None if ast_node.block_else is None else self.mutate(ast_node.block_else) return ast_node + def mutate_Filter(self, ast_node): + return self.mutate_Branch(ast_node) + def mutate_Cast(self, ast_node): ast_node.expr = self.mutate(ast_node.expr) return ast_node @@ -109,6 +62,9 @@ class Mutator: ast_node.block = self.mutate(ast_node.block) return ast_node + def mutate_ParticleFor(self, ast_node): + return self.mutate_For(ast_node) + def mutate_Malloc(self, ast_node): ast_node.array = self.mutate(ast_node.array) ast_node.size = self.mutate(ast_node.size) diff --git a/ast/visitor.py b/ast/visitor.py index 419b5e7df9e69eb7992eb49300bd7d925e38fd7b..da2ff9c04b4827a115162144df995956fb4916a1 100644 --- a/ast/visitor.py +++ b/ast/visitor.py @@ -1,49 +1,53 @@ -from ast.bin_op import BinOp -from ast.loops import For, ParticleFor, While +from collections import deque class Visitor: - def __init__(self, ast, max_depth=0): + def __init__(self, ast, max_depth=0, breadth_first=False): self.ast = ast - self.max_depth = 0 + self.max_depth = max_depth + self.breadth_first = breadth_first - def visit(ast_node): - for c in ast_node.children(): - if isinstance(c, Array): - self.visit_Array(c) - elif isinstance(c, BinOp): - self.visit_BinOp(c) - elif isinstance(c, (For, ParticleFor, While)): - self.visit_Loop(c) - else: - self.visit(c) + def get_method(self, method_name): + method = getattr(self, method_name, None) + return method if callable(method) else None - def visit_Array(self, ast_node): - return self.visit(ast_node) + def visit(self, ast_node=None): + if ast_node is None: + ast_node = self.ast - def visit_BinOp(self, ast_node): - return self.visit(ast_node) + method = self.get_method(f"visit_{type(ast_node).__name__}") + if method is not None: + method(ast_node) + else: + self.keep_visiting(ast_node) - def visit_Loop(self, ast_node): - return self.visit(ast_node) + def keep_visiting(self, ast_node): + for c in ast_node.children(): + self.visit(c) - def list_ast(self): - self.list_elements(self.ast) + def yield_elements_breadth_first(self, ast_node=None): + nodes_to_visit = deque() - def list_elements(self, ast): - ast_list = [ast] + if ast_node is None: + ast_node = self.ast - for c in self.ast.children(): - ast_list += self.list_elements(c) + nodes_to_visit.append(ast_node) - return ast_list + while nodes_to_visit: + next_node = nodes_to_visit.popleft() # nodes_to_visit.pop() for depth-first traversal + yield next_node + for c in next_node.children(): + nodes_to_visit.append(c) - def yield_elements(ast, depth, max_depth): - yield ast - if depth < max_depth or max_depth == 0: + def yield_elements(self, ast, depth): + if depth < self.max_depth or self.max_depth == 0: for child in ast.children(): - for child_node in Visitor.yield_elements(child, depth + 1, max_depth): - yield child_node + yield child + yield from self.yield_elements(child, depth + 1) def __iter__(self): - yield from Visitor.yield_elements(self.ast, 0, self.max_depth) + if self.breadth_first: + yield from self.yield_elements_breadth_first(self.ast) + else: + yield self.ast + yield from self.yield_elements(self.ast, 1) diff --git a/code_gen/cgen.py b/code_gen/cgen.py index ea87d2a6cffda99eb69d73386ad1eadc7ee89a83..c28e353d69cf64fe502dd8000c14feac669f0dd0 100644 --- a/code_gen/cgen.py +++ b/code_gen/cgen.py @@ -183,7 +183,7 @@ class CGen: return f"({lhs} {ast_node.operator()} {rhs})" # Some expressions can be defined on-the-fly during transformations, hence they do not have - # a definition statement in the tree, so we generate them right before usage + # a definition statement in the tree, so we generate them right before use if not ast_node.generated: CGen.generate_statement(sim, ast_node.definition()) diff --git a/transformations/LICM.py b/transformations/LICM.py new file mode 100644 index 0000000000000000000000000000000000000000..a0edaf702142b0584b33f9800b32ff77a92340c3 --- /dev/null +++ b/transformations/LICM.py @@ -0,0 +1,117 @@ +from ast.mutator import Mutator +from ast.visitor import Visitor + + +class SetBlockVariants(Mutator): + def __init__(self, ast): + super().__init__(ast) + self.current_block = None + self.in_assignment = None + + def mutate_Block(self, ast_node): + self.current_block = ast_node + ast_node.stmts = [self.mutate(s) for s in ast_node.stmts] + return ast_node + + def mutate_Assign(self, ast_node): + self.in_assignment = ast_node + for dest in ast_node.destinations(): + self.mutate(dest) + self.in_assignment = None + return ast_node + + def mutate_Array(self, ast_node): + if self.in_assignment is not None: + self.in_assignment.parent_block.add_variant(ast_node) + + return ast_node + + def mutate_For(self, ast_node): + ast_node.block.add_variant(ast_node.iterator) + ast_node.iterator = self.mutate(ast_node.iterator) + ast_node.block = self.mutate(ast_node.block) + return ast_node + + def mutate_Property(self, ast_node): + if self.in_assignment is not None: + self.in_assignment.parent_block.add_variant(ast_node) + + return ast_node + + def mutate_Variable(self, ast_node): + if self.in_assignment is not None: + self.in_assignment.parent_block.add_variant(ast_node) + + return ast_node + + +class SetParentBlock(Visitor): + def __init__(self, ast): + super().__init__(ast) + self.blocks = [] + + def current_block(self): + return self.blocks[-1] + + def visit_Assign(self, ast_node): + ast_node.parent_block = self.current_block + self.keep_visiting(ast_node) + + def visit_Block(self, ast_node): + ast_node.parent_block = self.current_block + self.blocks.append(ast_node) + self.keep_visiting(ast_node) + self.blocks.pop() + + def visit_BinOpDef(self, ast_node): + ast_node.parent_block = self.current_block + self.keep_visiting(ast_node) + + def visit_Branch(self, ast_node): + ast_node.parent_block = self.current_block + self.keep_visiting(ast_node) + + def visit_For(self, ast_node): + ast_node.parent_block = self.current_block + self.keep_visiting(ast_node) + + def visit_Malloc(self, ast_node): + ast_node.parent_block = self.current_block + self.keep_visiting(ast_node) + + def visit_Realloc(self, ast_node): + ast_node.parent_block = self.current_block + self.keep_visiting(ast_node) + + def visit_While(self, ast_node): + ast_node.parent_block = self.current_block + self.keep_visiting(ast_node) + + def get_loop_parent_block(self, ast_node): + assert isinstance(ast_node, (For, While)), "Node must be a loop!" + loop_id = id(ast_node) + return self.parents[loop_id] if loop_id in self.parents else None + + +class LICM(Mutator): + def __init__(self, ast, loop_parents): + super().__init__(ast) + self.loop_parents = loop_parents + + def mutate_For(self, ast_node): + ast_node.iterator = self.mutate(ast_node.iterator) + ast_node.block = self.mutate(ast_node.block) + return ast_node + + def mutate_Block(self, ast_node): + ast_node.stmts = [self.mutate(s) for s in ast_node.stmts] + return ast_node + + +def move_loop_invariant_code(ast): + set_parent_block = SetParentBlock(ast) + set_parent_block.visit() + set_block_variants = SetBlockVariants(ast) + set_block_variants.mutate() + licm = LICM(ast, set_loop_parents) + licm.mutate()