From 603dafd527a65245d24873d49469c6706ddffcf6 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti Lucio Machado <rafael.r.ravedutti@fau.de>
Date: Fri, 15 Jan 2021 02:26:20 +0100
Subject: [PATCH] Fix LICM visitors and just render graph instead of viewing it

Signed-off-by: Rafael Ravedutti Lucio Machado <rafael.r.ravedutti@fau.de>
---
 ast/visitor.py             |  4 ++--
 graph/graphviz.py          |  5 ++++-
 sim/particle_simulation.py |  2 +-
 transformations/LICM.py    | 26 +++++++++++++++++---------
 4 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/ast/visitor.py b/ast/visitor.py
index da2ff9c..b40a8d5 100644
--- a/ast/visitor.py
+++ b/ast/visitor.py
@@ -19,9 +19,9 @@ class Visitor:
         if method is not None:
             method(ast_node)
         else:
-            self.keep_visiting(ast_node)
+            self.visit_children(ast_node)
 
-    def keep_visiting(self, ast_node):
+    def visit_children(self, ast_node):
         for c in ast_node.children():
             self.visit(c)
 
diff --git a/graph/graphviz.py b/graph/graphviz.py
index 21a7e81..7af2c1e 100644
--- a/graph/graphviz.py
+++ b/graph/graphviz.py
@@ -14,7 +14,7 @@ class ASTGraph:
         self.graph.attr(size='6,6')
         self.visitor = Visitor(ast_node, max_depth=max_depth)
 
-    def generate_and_view(self):
+    def render(self):
         def generate_edges_for_node(ast_node, graph, generated):
             node_id = id(ast_node)
             if not isinstance(ast_node, BinOpDef) and node_id not in generated:
@@ -31,6 +31,9 @@ class ASTGraph:
         for node in self.visitor:
             generate_edges_for_node(node, self.graph, generated)
 
+        self.graph.render()
+
+    def view(self):
         self.graph.view()
 
     def get_node_label(ast_node):
diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py
index 53f8ef1..ae2f814 100644
--- a/sim/particle_simulation.py
+++ b/sim/particle_simulation.py
@@ -199,5 +199,5 @@ class ParticleSimulation:
         flatten_property_accesses(program)
         simplify_expressions(program)
 
-        ASTGraph(self.kernels.lower(), "kernels").generate_and_view()
+        ASTGraph(self.kernels.lower(), "kernels").render()
         self.code_gen.generate_program(self, program)
diff --git a/transformations/LICM.py b/transformations/LICM.py
index a0edaf7..07517b5 100644
--- a/transformations/LICM.py
+++ b/transformations/LICM.py
@@ -55,37 +55,45 @@ class SetParentBlock(Visitor):
 
     def visit_Assign(self, ast_node):
         ast_node.parent_block = self.current_block
-        self.keep_visiting(ast_node)
+        self.visit_children(ast_node)
 
     def visit_Block(self, ast_node):
         ast_node.parent_block = self.current_block
         self.blocks.append(ast_node)
-        self.keep_visiting(ast_node)
+        self.visit_children(ast_node)
         self.blocks.pop()
 
     def visit_BinOpDef(self, ast_node):
         ast_node.parent_block = self.current_block
-        self.keep_visiting(ast_node)
+        self.visit_children(ast_node)
 
     def visit_Branch(self, ast_node):
         ast_node.parent_block = self.current_block
-        self.keep_visiting(ast_node)
+        self.visit_children(ast_node)
+
+    def visit_Filter(self, ast_node):
+        ast_node.parent_block = self.current_block
+        self.visit_children(ast_node)
 
     def visit_For(self, ast_node):
         ast_node.parent_block = self.current_block
-        self.keep_visiting(ast_node)
+        self.visit_children(ast_node)
+
+    def visit_ParticleFor(self, ast_node):
+        ast_node.parent_block = self.current_block
+        self.visit_children(ast_node)
 
     def visit_Malloc(self, ast_node):
         ast_node.parent_block = self.current_block
-        self.keep_visiting(ast_node)
+        self.visit_children(ast_node)
 
     def visit_Realloc(self, ast_node):
         ast_node.parent_block = self.current_block
-        self.keep_visiting(ast_node)
+        self.visit_children(ast_node)
 
     def visit_While(self, ast_node):
         ast_node.parent_block = self.current_block
-        self.keep_visiting(ast_node)
+        self.visit_children(ast_node)
 
     def get_loop_parent_block(self, ast_node):
         assert isinstance(ast_node, (For, While)), "Node must be a loop!"
@@ -113,5 +121,5 @@ def move_loop_invariant_code(ast):
     set_parent_block.visit()
     set_block_variants = SetBlockVariants(ast)
     set_block_variants.mutate()
-    licm = LICM(ast, set_loop_parents)
+    licm = LICM(ast)
     licm.mutate()
-- 
GitLab