diff --git a/ast/visitor.py b/ast/visitor.py index da2ff9c04b4827a115162144df995956fb4916a1..b40a8d52d41632df430738b116048d5666b3d1e0 100644 --- a/ast/visitor.py +++ b/ast/visitor.py @@ -19,9 +19,9 @@ class Visitor: if method is not None: method(ast_node) else: - self.keep_visiting(ast_node) + self.visit_children(ast_node) - def keep_visiting(self, ast_node): + def visit_children(self, ast_node): for c in ast_node.children(): self.visit(c) diff --git a/graph/graphviz.py b/graph/graphviz.py index 21a7e8195ecb3e03fa92860d9f49ecbeee01528c..7af2c1eae86b231e01c817e30e759d1be9ecd7d8 100644 --- a/graph/graphviz.py +++ b/graph/graphviz.py @@ -14,7 +14,7 @@ class ASTGraph: self.graph.attr(size='6,6') self.visitor = Visitor(ast_node, max_depth=max_depth) - def generate_and_view(self): + def render(self): def generate_edges_for_node(ast_node, graph, generated): node_id = id(ast_node) if not isinstance(ast_node, BinOpDef) and node_id not in generated: @@ -31,6 +31,9 @@ class ASTGraph: for node in self.visitor: generate_edges_for_node(node, self.graph, generated) + self.graph.render() + + def view(self): self.graph.view() def get_node_label(ast_node): diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py index 53f8ef1716087fa78291eaabbb7702b4ac629ecb..ae2f814c0ad8e90dfa6c031962474af1e53e0253 100644 --- a/sim/particle_simulation.py +++ b/sim/particle_simulation.py @@ -199,5 +199,5 @@ class ParticleSimulation: flatten_property_accesses(program) simplify_expressions(program) - ASTGraph(self.kernels.lower(), "kernels").generate_and_view() + ASTGraph(self.kernels.lower(), "kernels").render() self.code_gen.generate_program(self, program) diff --git a/transformations/LICM.py b/transformations/LICM.py index a0edaf702142b0584b33f9800b32ff77a92340c3..07517b50cb4ddbab03925e119738244efc4b9512 100644 --- a/transformations/LICM.py +++ b/transformations/LICM.py @@ -55,37 +55,45 @@ class SetParentBlock(Visitor): def visit_Assign(self, ast_node): ast_node.parent_block = self.current_block - self.keep_visiting(ast_node) + self.visit_children(ast_node) def visit_Block(self, ast_node): ast_node.parent_block = self.current_block self.blocks.append(ast_node) - self.keep_visiting(ast_node) + self.visit_children(ast_node) self.blocks.pop() def visit_BinOpDef(self, ast_node): ast_node.parent_block = self.current_block - self.keep_visiting(ast_node) + self.visit_children(ast_node) def visit_Branch(self, ast_node): ast_node.parent_block = self.current_block - self.keep_visiting(ast_node) + self.visit_children(ast_node) + + def visit_Filter(self, ast_node): + ast_node.parent_block = self.current_block + self.visit_children(ast_node) def visit_For(self, ast_node): ast_node.parent_block = self.current_block - self.keep_visiting(ast_node) + self.visit_children(ast_node) + + def visit_ParticleFor(self, ast_node): + ast_node.parent_block = self.current_block + self.visit_children(ast_node) def visit_Malloc(self, ast_node): ast_node.parent_block = self.current_block - self.keep_visiting(ast_node) + self.visit_children(ast_node) def visit_Realloc(self, ast_node): ast_node.parent_block = self.current_block - self.keep_visiting(ast_node) + self.visit_children(ast_node) def visit_While(self, ast_node): ast_node.parent_block = self.current_block - self.keep_visiting(ast_node) + self.visit_children(ast_node) def get_loop_parent_block(self, ast_node): assert isinstance(ast_node, (For, While)), "Node must be a loop!" @@ -113,5 +121,5 @@ def move_loop_invariant_code(ast): set_parent_block.visit() set_block_variants = SetBlockVariants(ast) set_block_variants.mutate() - licm = LICM(ast, set_loop_parents) + licm = LICM(ast) licm.mutate()