Created
July 6, 2020 08:37
-
-
Save fredyr/d54a2e4731c2e3175b28b47069535da0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Take a CSV distance matrix and render it as a minimum spanning tree (MST)."""
import sys | |
import matplotlib.pyplot as plt | |
import networkx as nx | |
import pandas | |
from scipy.sparse.csgraph import minimum_spanning_tree | |
def read_distmat(f):
    """Load a labelled distance matrix from a comma-separated file.

    Parameters
    ----------
    f : str or file-like
        Anything ``pandas.read_csv`` accepts as its source.

    Returns
    -------
    pandas.DataFrame
        The parsed table, with the first CSV column used as the row index.
    """
    frame = pandas.read_csv(f, index_col=0, sep=',')
    return frame
def _mst_graph(df, thres=None):
    """Build the MST of distance matrix ``df`` as a networkx graph, pruned at ``thres``."""
    # DataFrame.to_numpy() replaces the discouraged DataFrame.values accessor.
    # Note: SciPy's minimum_spanning_tree treats zero entries as "no edge".
    tree = minimum_spanning_tree(df.to_numpy())
    # networkx 3.0 removed from_scipy_sparse_matrix; from_scipy_sparse_array
    # (networkx >= 2.7) is its replacement. Fall back for older installs.
    from_sparse = (getattr(nx, "from_scipy_sparse_array", None)
                   or nx.from_scipy_sparse_matrix)
    graph = from_sparse(tree)
    # Remove edges above the threshold to split the tree into clusters.
    # An explicit None check so that a threshold of 0 is honoured.
    if thres is not None:
        heavy = [(u, v) for u, v, w in graph.edges(data='weight') if w > thres]
        graph.remove_edges_from(heavy)
    return graph


def _draw_graph(graph, labels):
    """Draw ``graph`` with a spring layout, node labels and edge weights, then show it."""
    layout = nx.spring_layout(graph)
    nx.draw(graph, pos=layout, node_size=1024, node_color='skyblue')
    nx.draw_networkx_labels(graph, pos=layout, labels=labels, font_size=6)
    # Weights are truncated to int to keep the edge labels compact.
    edge_labels = {(u, v): int(w) for u, v, w in graph.edges(data='weight')}
    nx.draw_networkx_edge_labels(graph, pos=layout, edge_labels=edge_labels,
                                 font_size=6)
    plt.show()


def mst(df, thres=None):
    """Compute, draw and display the minimum spanning tree of a distance matrix.

    Parameters
    ----------
    df : pandas.DataFrame
        Square distance matrix. Node ``i`` is labelled with ``df.columns[i]``,
        which assumes rows and columns are in the same order.
    thres : float, optional
        If given, MST edges with weight strictly greater than ``thres`` are
        removed, which splits the tree into clusters. ``None`` (the default)
        keeps every edge.

    Returns
    -------
    networkx.Graph
        The (possibly pruned) tree that was drawn.
    """
    graph = _mst_graph(df, thres)
    labels = dict(enumerate(df.columns))
    _draw_graph(graph, labels)
    return graph
# TODO: `DataFrame.values` is discouraged in favour of `DataFrame.to_numpy()`,
# which requires pandas >= 0.24; make sure the installed pandas is new enough.
# Script entry point: `python mst.py DISTMAT.csv [THRESHOLD]`.
# Guarded so that importing this module does not read sys.argv or open a window.
if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Report a usage message instead of crashing with an IndexError.
        sys.exit(f"usage: {sys.argv[0]} DISTMAT.csv [THRESHOLD]")
    # An optional second argument is forwarded to mst() as the clustering threshold.
    threshold = float(sys.argv[2]) if len(sys.argv) > 2 else None
    mst(read_distmat(sys.argv[1]), thres=threshold)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment