Last active
May 9, 2023 16:22
-
-
Save jgabriellima/a4ed73df4aa28cb9980520b6de5fe8cf to your computer and use it in GitHub Desktop.
This code creates a call flow graph (plus a class diagram, sequence diagram and flowchart, emitted as Mermaid text) from a Python package.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import ast
import base64
import os
import subprocess

import matplotlib.pyplot as plt
import networkx as nx
class FlowchartAnalyzer(ast.NodeVisitor):
    """Collect the if/while/for statements found inside each function.

    After visiting one or more module trees, ``flowchart`` maps a function's
    bare name to a list of ``(control_type, detail)`` tuples in source order.
    ``control_type`` is ``'if'``, ``'while'`` or ``'for'``. ``detail`` is the
    ``ast.dump`` of the condition (if/while) or of the iterable (for).

    Functions are keyed by bare name, so a later function with the same name
    (another class's method, or a function in another file) replaces the
    earlier entry. Control flow outside any function is not recorded.
    """

    def __init__(self):
        self.current_function = None  # name of the innermost enclosing def, or None
        self.flowchart = {}

    def visit_FunctionDef(self, node):
        # Save and restore the enclosing function rather than resetting it to
        # None. After a nested def, statements later in the outer function
        # must still be attributed to the outer function.
        outer = self.current_function
        self.current_function = node.name
        self.flowchart[node.name] = []
        self.generic_visit(node)
        self.current_function = outer

    visit_AsyncFunctionDef = visit_FunctionDef

    def _record(self, control_type, expr):
        """Append ``(control_type, ast.dump(expr))`` to the current function.

        Statements outside any function (e.g. a module-level
        ``if __name__ == '__main__':``) have no entry to append to and are
        skipped. Their children are still visited by the caller.
        """
        if self.current_function is not None:
            self.flowchart[self.current_function].append(
                (control_type, ast.dump(expr)))

    def visit_If(self, node):
        self._record('if', node.test)
        self.generic_visit(node)

    def visit_While(self, node):
        self._record('while', node.test)
        self.generic_visit(node)

    def visit_For(self, node):
        self._record('for', node.iter)
        self.generic_visit(node)

    visit_AsyncFor = visit_For
def flowchart_to_mermaid(flowchart):
    """Render a ``{function: [(control_type, condition), ...]}`` dict as Mermaid.

    Each entry becomes a node ``<function>_<i>`` whose double-quoted label is
    ``"<control_type> <condition>"``. Consecutive entries of the same function
    are chained with ``-->`` edges in list order. A function with no entries
    produces no nodes.

    Returns:
        The Mermaid ``graph TD`` text, one statement per line, ending in a newline.
    """
    lines = ["graph TD"]
    for function, controls in flowchart.items():
        for i, (control_type, condition) in enumerate(controls):
            # The label is wrapped in double quotes, so a literal '"' (which
            # ast.dump emits for strings containing an apostrophe) would end
            # it early. Mermaid escapes with entity codes, not backslashes,
            # so use #quot;. Single quotes are legal inside the label as-is.
            condition = condition.replace("\n", "\\n").replace('"', "#quot;")
            node_id = f"{function}_{i}"
            lines.append(f'{node_id}["{control_type} {condition}"]')
            if i > 0:
                lines.append(f"{function}_{i - 1} --> {node_id}")
    return "\n".join(lines) + "\n"
def get_flowchart(pkg_path):
    """Run FlowchartAnalyzer over every ``.py`` file below *pkg_path*.

    Files are read as UTF-8. A file that cannot be opened, decoded, parsed
    or analysed is reported on stdout and skipped, and the walk continues.

    Returns:
        The analyzer's ``flowchart`` dict accumulated across all files.
    """
    analyzer = FlowchartAnalyzer()
    for root, _, files in os.walk(pkg_path):
        for file in files:
            if not file.endswith(".py"):
                continue
            file_path = os.path.join(root, file)
            # open() is inside the try so an unreadable file or a dangling
            # symlink is skipped like a syntax error instead of aborting the walk.
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    source = f.read()
                analyzer.visit(ast.parse(source, filename=file_path))
            except Exception as e:
                print(f"Error parsing {file_path}: {e}")
    return analyzer.flowchart
def clone_github_repo(github_link, local_dir, *, check=False):
    """Run ``git clone`` for *github_link* into *local_dir*.

    *local_dir* (and any missing parents) is created first if needed.

    Args:
        github_link: repository URL or path handed to ``git clone``.
        local_dir: destination directory.
        check: when True, raise ``subprocess.CalledProcessError`` if git
            exits non-zero. Otherwise inspect the returned object's ``returncode``.

    Returns:
        The ``subprocess.CompletedProcess`` for the git invocation.
    """
    os.makedirs(local_dir, exist_ok=True)
    # "--" ends option parsing, so a link that begins with "-" (for example
    # "--upload-pack=...") is treated as a repository, never as a git option.
    return subprocess.run(["git", "clone", "--", github_link, local_dir], check=check)
class CallGraphAnalyzer(ast.NodeVisitor):
    """Add ``caller -> callee`` edges to *call_graph* for calls made inside functions.

    Defined functions are named ``<module>.<func>`` or
    ``<module>.<Class>.<method>``, using ``current_module`` (set it before
    visiting each file). Call targets are named as follows:
      * ``name()``      -> ``<module>.name``
      * ``self.attr()`` -> ``<module>.<Class>.attr`` inside a class
      * ``obj.attr()``  -> ``obj.attr``
    Other call shapes (``a.b.c()``, ``f()()``, ...) are ignored, as are calls
    made outside any function (module or class body level).

    With ``exclude_libs=True`` an edge is added only if the callee is already
    in ``self.functions``, i.e. was defined earlier in the visiting order.
    """

    def __init__(self, call_graph, exclude_libs=False):
        self.call_graph = call_graph    # receives add_edge(caller, callee), e.g. nx.DiGraph
        self.current_function = None    # qualified name of the innermost def, or None
        self.exclude_libs = exclude_libs
        self.functions = set()          # qualified names of every def visited so far
        self.current_module = ""
        self.current_class = None

    def visit_ClassDef(self, node):
        # Restore the enclosing class afterwards so methods that follow a
        # nested class keep their correct qualified name.
        outer = self.current_class
        self.current_class = node.name
        self.generic_visit(node)
        self.current_class = outer

    def visit_FunctionDef(self, node):
        if self.current_class:
            func_name = f"{self.current_module}.{self.current_class}.{node.name}"
        else:
            func_name = f"{self.current_module}.{node.name}"
        self.functions.add(func_name)
        # Restore the enclosing function afterwards. Otherwise module-level
        # calls after a def are attributed to that def.
        outer = self.current_function
        self.current_function = func_name
        self.generic_visit(node)
        self.current_function = outer

    visit_AsyncFunctionDef = visit_FunctionDef

    def _callee_name(self, func):
        """Return the graph node name for the call target *func*, or None."""
        if isinstance(func, ast.Name):
            if func.id == 'self' and self.current_class:
                return f"{self.current_module}.{self.current_class}"
            return f"{self.current_module}.{func.id}"
        if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
            if func.value.id == 'self' and self.current_class:
                return f"{self.current_module}.{self.current_class}.{func.attr}"
            return f"{func.value.id}.{func.attr}"
        return None

    def visit_Call(self, node):
        callee = self._callee_name(node.func)
        # A call outside any function has no caller node. networkx rejects
        # None as a node, which used to abort the whole file.
        if (callee is not None and self.current_function is not None
                and (not self.exclude_libs or callee in self.functions)):
            self.call_graph.add_edge(self.current_function, callee)
        self.generic_visit(node)
class ClassDiagramAnalyzer(ast.NodeVisitor):
    """Collect classes, their methods and a base class for a class diagram.

    ``classes`` maps a class name to the names of the ``def``/``async def``
    statements directly in its body, in source order.

    ``inheritance`` maps a class name to ONE base-class name. Only bases
    written as a plain name (``class A(B)``) are considered. When several are
    listed, the last one wins, because the dict holds a single base per class.

    Classes are keyed by bare name, so same-named classes (e.g. in different
    files) overwrite each other.
    """

    def __init__(self):
        self.classes = {}
        self.inheritance = {}

    def visit_ClassDef(self, node):
        # async methods belong on the diagram too.
        self.classes[node.name] = [
            stmt.name for stmt in node.body
            if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        for base in node.bases:
            if isinstance(base, ast.Name):
                self.inheritance[node.name] = base.id
        self.generic_visit(node)
def mm(graph):
    """Print and return a mermaid.ink image URL for the Mermaid text *graph*.

    The diagram text is UTF-8 encoded and then base64-encoded into the URL
    path. For ASCII-only text this is byte-identical to an ASCII encoding.
    Non-ASCII identifiers (legal in Python) no longer raise UnicodeEncodeError.
    """
    encoded = base64.b64encode(graph.encode("utf-8")).decode("ascii")
    url = "https://mermaid.ink/img/" + encoded
    print(url)
    return url
def get_call_flow_graph_package(pkg_path, exclude_libs=False):
    """Build a call graph (``nx.DiGraph``) for every ``.py`` file below *pkg_path*.

    Each file's module name is its basename without ``.py``. A file that
    cannot be opened, decoded, parsed or analysed is reported on stdout and
    skipped.

    With ``exclude_libs=True`` only calls to functions defined somewhere in
    the package are kept. Definitions are collected in a first pass over all
    files, so a call to a function defined further down (or in a file walked
    later) is not dropped.
    """
    call_graph = nx.DiGraph()
    analyzer = CallGraphAnalyzer(call_graph, exclude_libs)

    # Read and parse each file once; both passes below reuse the trees.
    modules = []  # (file_path, module_name, tree)
    for root, _, files in os.walk(pkg_path):
        for file in files:
            if not file.endswith(".py"):
                continue
            file_path = os.path.join(root, file)
            # open() is inside the try so an unreadable file or a dangling
            # symlink is skipped instead of aborting the whole walk.
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    tree = ast.parse(f.read(), filename=file_path)
            except Exception as e:
                print(f"Error parsing {file_path}: {e}")
                continue
            modules.append((file_path, os.path.splitext(file)[0], tree))

    if exclude_libs:
        # Collection pass: populate analyzer.functions, sending edges to a
        # throwaway graph. Failures here are reported by the edge pass.
        analyzer.call_graph = nx.DiGraph()
        for _, module_name, tree in modules:
            analyzer.current_module = module_name
            try:
                analyzer.visit(tree)
            except Exception:
                pass
        # Start the edge pass from a clean visitor context.
        analyzer.call_graph = call_graph
        analyzer.current_function = None
        analyzer.current_class = None

    for file_path, module_name, tree in modules:
        analyzer.current_module = module_name
        try:
            analyzer.visit(tree)
        except Exception as e:
            print(f"Error parsing {file_path}: {e}")
    return call_graph
def get_call_flow_graph(file_path, class_or_function_name, exclude_libs=False):
    """Build a call graph (``nx.DiGraph``) rooted at one class or function of a file.

    Every class, ``def`` or ``async def`` in *file_path* named
    *class_or_function_name* (at any nesting depth) is visited with
    CallGraphAnalyzer. No module name is set, so node names begin with ``.``.

    With ``exclude_libs=True`` only calls to functions defined inside the
    visited targets are kept. Those definitions are collected in a first pass,
    so a method calling a method defined further down still gets its edge.

    Raises:
        OSError: if *file_path* cannot be opened. Decode, parse and analysis
        errors are printed and an empty or partial graph is returned.
    """
    call_graph = nx.DiGraph()
    analyzer = CallGraphAnalyzer(call_graph, exclude_libs)
    with open(file_path, "r", encoding="utf-8") as f:
        try:
            file_ast = ast.parse(f.read(), filename=file_path)
            targets = [
                node for node in ast.walk(file_ast)
                if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef))
                and node.name == class_or_function_name
            ]
            if exclude_libs:
                # Collection pass into a throwaway graph: fills analyzer.functions.
                analyzer.call_graph = nx.DiGraph()
                for node in targets:
                    analyzer.visit(node)
                analyzer.call_graph = call_graph
                analyzer.current_function = None
                analyzer.current_class = None
            for node in targets:
                analyzer.visit(node)
        except Exception as e:
            print(f"Error parsing {file_path}: {e}")
    return call_graph
def get_class_diagram(pkg_path):
    """Run ClassDiagramAnalyzer over every ``.py`` file below *pkg_path*.

    Files are read as UTF-8. A file that cannot be opened, decoded, parsed
    or analysed is reported on stdout and skipped.

    Returns:
        ``(classes, inheritance)``: the analyzer's two dicts accumulated
        across all files.
    """
    analyzer = ClassDiagramAnalyzer()
    for root, _, files in os.walk(pkg_path):
        for file in files:
            if not file.endswith(".py"):
                continue
            file_path = os.path.join(root, file)
            # open() is inside the try so an unreadable file or a dangling
            # symlink is skipped like a syntax error instead of aborting the walk.
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    source = f.read()
                analyzer.visit(ast.parse(source, filename=file_path))
            except Exception as e:
                print(f"Error parsing {file_path}: {e}")
    return analyzer.classes, analyzer.inheritance
def draw_call_flow_graph(call_graph):
    """Display *call_graph* in a matplotlib window using a spring layout."""
    draw_options = dict(
        with_labels=True,
        node_size=2000,
        font_size=10,
        font_weight='bold',
        arrows=True,
    )
    layout = nx.spring_layout(call_graph)
    nx.draw(call_graph, layout, **draw_options)
    plt.show()
def graph_to_text(call_graph):
    """Describe every edge of *call_graph* as one ``"<caller> calls <callee>"`` line.

    Edges are listed in the iteration order of ``call_graph.adj``, after a
    ``"Call Flow Graph:"`` header line.
    """
    edge_lines = [
        f"{caller} calls {callee}\n"
        for caller, neighbours in call_graph.adj.items()
        for callee in neighbours
    ]
    return "Call Flow Graph:\n" + "".join(edge_lines)
def text_to_mermaid(text, orientation="TB"):
    """Convert ``"<caller> calls <callee>"`` lines into a Mermaid flowchart.

    Only lines made of exactly three whitespace-separated tokens whose middle
    token is ``calls`` become ``caller --> callee`` edges. The header, blank
    lines and any other text are ignored.

    Args:
        text: text in the format produced by graph_to_text.
        orientation: Mermaid direction keyword placed after ``graph`` (e.g. ``TB``, ``LR``).

    Returns:
        The Mermaid text, one statement per line, ending in a newline.
    """
    mermaid_lines = [f"graph {orientation}"]
    for line in text.split("\n"):
        tokens = line.split()
        # Match the token, not a substring. A free-text line that merely
        # contains "calls" would otherwise fail the 3-way unpack.
        if len(tokens) == 3 and tokens[1] == "calls":
            caller, _, callee = tokens
            mermaid_lines.append(f"{caller} --> {callee}")
    return "\n".join(mermaid_lines) + "\n"
def class_diagram_to_text(class_diagram, inheritance):
    """Render class/method and inheritance mappings as plain text.

    Emits a header, then a ``Class <name>:`` line per class followed by one
    ``Method <name>`` line per method, then one ``<sub> inherits from <super>``
    line per inheritance entry.
    """
    parts = ["Class Diagram:\n"]
    for name, method_names in class_diagram.items():
        parts.append(f"Class {name}:\n")
        parts.extend(f" Method {m}\n" for m in method_names)
    parts.extend(
        f"{child} inherits from {parent}\n" for child, parent in inheritance.items()
    )
    return "".join(parts)
def class_diagram_to_mermaid(class_diagram, inheritance):
    """Render class/method and inheritance mappings as a Mermaid ``classDiagram``.

    Each class becomes a ``class Name { ... }`` block listing ``method()``
    members. Each inheritance entry becomes a ``Sub --|> Super`` relation.
    """
    chunks = ["classDiagram\n"]
    for name, method_names in class_diagram.items():
        members = "".join(f" {m}()\n" for m in method_names)
        chunks.append(f"class {name} {{\n{members}}}\n")
    chunks.extend(f"{child} --|> {parent}\n" for child, parent in inheritance.items())
    return "".join(chunks)
class SequenceDiagramAnalyzer(ast.NodeVisitor):
    """Record method calls made from inside classes to known class methods.

    ``classes`` maps each visited class name to a list of
    ``(calling_function, called_class, called_function)`` tuples.

    A call is recorded when it is made inside a class, has the form
    ``name.method(...)`` (``self`` resolves to the enclosing class), and
    ``"<Class>.<method>"`` is already in ``user_defined_functions``. That set
    is filled as definitions are visited, so in a single pass, calls to
    methods defined later are not yet known.
    """

    def __init__(self):
        self.classes = {}
        self.current_class = None     # innermost enclosing class name, or None
        self.current_function = None  # innermost enclosing def name, or None
        self.user_defined_functions = set()  # "Class.method" or bare function names

    def visit_ClassDef(self, node):
        # Restore the enclosing class afterwards (nested classes).
        outer = self.current_class
        self.current_class = node.name
        self.classes[node.name] = []
        self.generic_visit(node)
        self.current_class = outer

    def visit_FunctionDef(self, node):
        # Restore the enclosing function afterwards. Otherwise calls after a
        # nested def are recorded with a None caller.
        outer = self.current_function
        self.current_function = node.name
        if self.current_class:
            self.user_defined_functions.add(f"{self.current_class}.{node.name}")
        else:
            self.user_defined_functions.add(node.name)
        self.generic_visit(node)
        self.current_function = outer

    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Call(self, node):
        func = node.func
        # Outside a class there is no sender to draw. classes[None] would
        # raise KeyError and abort the file.
        if (self.current_class is not None
                and isinstance(func, ast.Attribute)
                and isinstance(func.value, ast.Name)):
            receiver = func.value.id
            called_class = self.current_class if receiver == 'self' else receiver
            if f"{called_class}.{func.attr}" in self.user_defined_functions:
                self.classes[self.current_class].append(
                    (self.current_function, called_class, func.attr))
        self.generic_visit(node)
def sequence_diagram_to_mermaid(sequence_diagram):
    """Render ``{class: [(caller, called_class, called_fn), ...]}`` as a Mermaid sequence diagram.

    Every tuple becomes one ``Class.caller ->> CalledClass.called_fn: call`` message.
    """
    messages = (
        f"{class_name}.{caller} ->> {target_class}.{target_fn}: call\n"
        for class_name, calls in sequence_diagram.items()
        for caller, target_class, target_fn in calls
    )
    return "sequenceDiagram\n" + "".join(messages)
def get_sequence_diagram(pkg_path):
    """Run SequenceDiagramAnalyzer over every ``.py`` file below *pkg_path*.

    Two passes are made over the parsed files. The first only fills the
    analyzer's ``user_defined_functions``. The second, starting from empty
    ``classes``, records the calls. A call to a method defined further down,
    or in a file walked later, is therefore recognised. A file that cannot be
    opened, decoded, parsed or analysed is reported on stdout and skipped.

    Returns:
        The analyzer's ``classes`` dict from the second pass.
    """
    analyzer = SequenceDiagramAnalyzer()

    trees = []  # (file_path, tree); each file is read and parsed once
    for root, _, files in os.walk(pkg_path):
        for file in files:
            if not file.endswith(".py"):
                continue
            file_path = os.path.join(root, file)
            # open() is inside the try so an unreadable file or a dangling
            # symlink is skipped instead of aborting the whole walk.
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    trees.append((file_path, ast.parse(f.read(), filename=file_path)))
            except Exception as e:
                print(f"Error parsing {file_path}: {e}")

    # Collection pass: failures here are reported by the recording pass.
    for _, tree in trees:
        try:
            analyzer.visit(tree)
        except Exception:
            pass

    # Recording pass from a clean state, with every definition now known.
    analyzer.classes = {}
    analyzer.current_class = None
    analyzer.current_function = None
    for file_path, tree in trees:
        try:
            analyzer.visit(tree)
        except Exception as e:
            print(f"Error parsing {file_path}: {e}")
    return analyzer.classes
if __name__ == '__main__':
    # Command-line entry point, e.g.:
    #   python script.py path/to/pkg --file path/to/pkg/mod.py --target ContractProcessor
    parser = argparse.ArgumentParser(
        description="Print Mermaid diagrams (call graph, class diagram, flowchart "
                    "and optionally a sequence diagram) for a Python package.")
    parser.add_argument("package_path", help="directory whose .py files are scanned")
    parser.add_argument("--file", help="build the call graph from this one file (requires --target)")
    parser.add_argument("--target", help="class or function name in --file to start the call graph from")
    parser.add_argument("--clone", metavar="GIT_URL",
                        help="git clone this repository into package_path before scanning")
    parser.add_argument("--sequence", action="store_true", help="also print a sequence diagram")
    parser.add_argument("--no-draw", action="store_true", help="do not open the matplotlib window")
    args = parser.parse_args()
    if bool(args.file) != bool(args.target):
        parser.error("--file and --target must be given together")

    if args.clone:
        clone_github_repo(args.clone, args.package_path)

    # get_call_flow_graph() opens its first argument as a file, so it must be
    # given a .py file, never the package directory.
    if args.file:
        graph = get_call_flow_graph(args.file, args.target, exclude_libs=True)
    else:
        graph = get_call_flow_graph_package(args.package_path, exclude_libs=True)
    if not args.no_draw:
        draw_call_flow_graph(graph)
    print(text_to_mermaid(graph_to_text(graph), 'LR'))

    class_diagram, inheritance = get_class_diagram(args.package_path)
    print(class_diagram_to_mermaid(class_diagram, inheritance))

    if args.sequence:
        print(sequence_diagram_to_mermaid(get_sequence_diagram(args.package_path)))

    flowchart = get_flowchart(args.package_path)
    print(flowchart_to_mermaid(flowchart))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment