diff --git a/ast/visitor.py b/ast/visitor.py index ead043e0fc5e888d69741bc9a18797dcb9cd2746..8cf72e97d3a6ce6952a83f9e50767f7fa272f276 100644 --- a/ast/visitor.py +++ b/ast/visitor.py @@ -1,8 +1,9 @@ class Visitor: - def __init__(self, ast, enter_fn, leave_fn): + def __init__(self, ast, enter_fn=None, leave_fn=None, max_depth=0): self.ast = ast self.enter_fn = enter_fn self.leave_fn = leave_fn + self.max_depth = max_depth def visit(self): self.visit_rec(self.ast) @@ -28,11 +29,12 @@ class Visitor: return ast_list - def yield_elements(self, ast): + def yield_elements(ast, depth, max_depth): yield ast - - for c in ast.children(): - self.yield_elements(c) + if depth < max_depth or max_depth == 0: + for child in ast.children(): + for child_node in Visitor.yield_elements(child, depth + 1, max_depth): + yield child_node def __iter__(self): - self.yield_elements(self.ast) + yield from Visitor.yield_elements(self.ast, 0, self.max_depth) diff --git a/graph/graphviz.py b/graph/graphviz.py new file mode 100644 index 0000000000000000000000000000000000000000..28be57f23f0ed9bb8b9c145098b00f1aa1ed00ee --- /dev/null +++ b/graph/graphviz.py @@ -0,0 +1,48 @@ +from ast.arrays import Array +from ast.expr import BinOp, BinOpDef +from ast.lit import Lit +from ast.loops import Iter +from ast.properties import Property +from ast.variables import Var +from ast.visitor import Visitor +from graphviz import Digraph + +class ASTGraph: + def __init__(self, ast_node, filename, ref="G", max_depth=0): + self.graph = Digraph(ref, filename, node_attr={'color': 'lightblue2', 'style': 'filled'}) + self.graph.attr(size='6,6') + self.visitor = Visitor(ast_node, max_depth=max_depth) + + def generate_and_view(self): + def generate_edges_for_node(ast_node, graph, generated): + node_id = id(ast_node) + if not isinstance(ast_node, BinOpDef) and node_id not in generated: + node_ref = f"n{id(ast_node)}" + generated.append(node_id) + graph.node(node_ref, label=ASTGraph.get_node_label(ast_node)) + for child in ast_node.children(): + if not isinstance(child, BinOpDef): + child_ref = f"n{id(child)}" + graph.node(child_ref, label=ASTGraph.get_node_label(child)) + graph.edge(node_ref, child_ref) + + generated = [] + for node in self.visitor: + generate_edges_for_node(node, self.graph, generated) + + self.graph.view() + + def get_node_label(ast_node): + if isinstance(ast_node, (Array, Property, Var)): + return ast_node.name() + + if isinstance(ast_node, BinOp): + return ast_node.operator() + + if isinstance(ast_node, Iter): + return f"Iter({ast_node.id()})" + + if isinstance(ast_node, Lit): + return str(ast_node.value) + + return type(ast_node).__name__