Skip to content
Snippets Groups Projects
Commit 0d946ca1 authored by Rafael Ravedutti Lucio Machado's avatar Rafael Ravedutti Lucio Machado
Browse files

Write first version of graph output

parent a2e1f1f4
No related branches found
No related tags found
No related merge requests found
class Visitor:
def __init__(self, ast, enter_fn, leave_fn):
def __init__(self, ast, enter_fn=None, leave_fn=None, max_depth=0):
self.ast = ast
self.enter_fn = enter_fn
self.leave_fn = leave_fn
self.max_depth = max_depth
def visit(self):
self.visit_rec(self.ast)
......@@ -28,11 +29,12 @@ class Visitor:
return ast_list
def yield_elements(self, ast):
def yield_elements(ast, depth, max_depth):
yield ast
for c in ast.children():
self.yield_elements(c)
if depth < max_depth or max_depth == 0:
for child in ast.children():
for child_node in Visitor.yield_elements(child, depth + 1, max_depth):
yield child_node
def __iter__(self):
self.yield_elements(self.ast)
yield from Visitor.yield_elements(self.ast, 0, self.max_depth)
from ast.arrays import Array
from ast.expr import BinOp, BinOpDef
from ast.lit import Lit
from ast.loops import Iter
from ast.properties import Property
from ast.variables import Var
from ast.visitor import Visitor
from graphviz import Digraph
class ASTGraph:
def __init__(self, ast_node, filename, ref="G", max_depth=0):
self.graph = Digraph(ref, filename, node_attr={'color': 'lightblue2', 'style': 'filled'})
self.graph.attr(size='6,6')
self.visitor = Visitor(ast_node, max_depth=max_depth)
def generate_and_view(self):
def generate_edges_for_node(ast_node, graph, generated):
node_id = id(ast_node)
if not isinstance(ast_node, BinOpDef) and node_id not in generated:
node_ref = f"n{id(ast_node)}"
generated.append(node_id)
graph.node(node_ref, label=ASTGraph.get_node_label(ast_node))
for child in ast_node.children():
if not isinstance(child, BinOpDef):
child_ref = f"n{id(child)}"
graph.node(child_ref, label=ASTGraph.get_node_label(child))
graph.edge(node_ref, child_ref)
generated = []
for node in self.visitor:
generate_edges_for_node(node, self.graph, generated)
self.graph.view()
def get_node_label(ast_node):
if isinstance(ast_node, (Array, Property, Var)):
return ast_node.name()
if isinstance(ast_node, BinOp):
return ast_node.operator()
if isinstance(ast_node, Iter):
return f"Iter({ast_node.id()})"
if isinstance(ast_node, Lit):
return str(ast_node.value)
return type(ast_node).__name__
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment