Created
January 14, 2019 07:28
-
-
Save PCJohn/ed503fad5c8b43e3da29ffc60b1313d3 to your computer and use it in GitHub Desktop.
Display graphs with networkx
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import numpy as np | |
from matplotlib import pyplot as plt | |
import networkx as nx | |
import copy | |
class Graph:
    """Graph built from an adjacency matrix, with centrality and plotting helpers.

    Attributes set by the constructor:
        adj:         the adjacency matrix exactly as passed in (indexed adj[i, j]).
        char_list:   the node labels, in matrix row/column order.
        lab:         dict mapping node index -> label.
        edge_labels: dict mapping (i, j) -> adj[i, j] for every entry > 0.
        G:           the networkx graph created from ``adj``.
    """

    # Accepted values for ``cent_type``; each one selects the networkx
    # function named ``<value>_centrality``.
    _CENTRALITY_TYPES = (
        'betweenness',
        'closeness',
        'eigenvector',
        'harmonic',
        'katz',
        'load',
        'degree',
    )

    def __init__(self, adj, labels):
        """Store the matrix and labels and build the networkx graph.

        Args:
            adj: square numpy adjacency matrix; entries > 0 are labelled edges.
            labels: sequence of node labels, one per row/column of ``adj``.
        """
        self.adj = adj
        self.char_list = labels
        self.lab = dict(enumerate(labels))
        size = len(labels)
        self.edge_labels = {
            (n1, n2): adj[n1, n2]
            for n1 in range(size)
            for n2 in range(size)
            if adj[n1, n2] > 0
        }
        # from_numpy_matrix was removed in networkx 3.0; prefer the
        # replacement and only fall back to the old name on old releases.
        build_graph = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix
        self.G = build_graph(adj)

    def most_central(self, F=1, cent_type='betweenness'):
        """Score every node and return a copy of the graph without the low scorers.

        Args:
            F: multiplier applied to the mean centrality to get the cutoff.
            cent_type: one of 'betweenness', 'closeness', 'eigenvector',
                'harmonic', 'katz', 'load' or 'degree'.

        Returns:
            A tuple ``(cent_dict, thresh, g)`` where ``cent_dict`` maps node
            label -> centrality, ``thresh`` is ``F * mean centrality`` and
            ``g`` is a new Graph in which every node whose centrality is
            strictly below ``thresh`` has been removed.

        Raises:
            ValueError: if ``cent_type`` is not one of the supported names.
        """
        # Validate first so a typo gives a clear error instead of an
        # unbound-variable failure further down.
        if cent_type not in self._CENTRALITY_TYPES:
            raise ValueError(
                'unknown cent_type %r; expected one of %s'
                % (cent_type, ', '.join(self._CENTRALITY_TYPES))
            )
        measure = getattr(nx, cent_type + '_centrality')
        ranking = measure(self.G)

        ranks = list(ranking.values())
        cent_dict = {self.lab[n]: r for n, r in ranking.items()}
        m_centrality = sum(ranks)
        if ranks:
            m_centrality = m_centrality / len(ranks)

        # Build a fresh graph and strip the nodes under the cutoff.
        thresh = F * m_centrality
        g = Graph(self.adj.copy(), self.char_list)
        # The copy is rebuilt from the full matrix, so also drop nodes that are
        # already absent from this graph (they have no ranking entry);
        # otherwise pruning a pruned graph would bring them back.
        dropped = [
            n for n in list(g.G.nodes())
            if n not in ranking or ranking[n] < thresh
        ]
        for n in dropped:
            g.G.remove_node(n)
            g.lab.pop(n, None)
        # Keep the edge labels in step with the surviving nodes: show() lays
        # out only the nodes still in the graph, so a label pointing at a
        # removed node has no position to be drawn at.
        g.edge_labels = {
            (u, v): w
            for (u, v), w in g.edge_labels.items()
            if g.G.has_node(u) and g.G.has_node(v)
        }
        return (cent_dict, thresh, g)

    def show(self, path=None):
        """Draw the graph with node and edge labels.

        Args:
            path: file to save the figure to; when None the figure is
                displayed with ``plt.show()`` instead.
        """
        pos = nx.spring_layout(self.G)
        nx.draw_networkx_nodes(self.G, pos, alpha=0.5)
        nx.draw_networkx_labels(self.G, pos, self.lab, alpha=0.5)
        nx.draw_networkx_edges(self.G, pos, alpha=0.5)
        nx.draw_networkx_edge_labels(self.G, pos, edge_labels=self.edge_labels, alpha=1)
        if path is None:
            plt.show()
        else:
            plt.savefig(path)
if __name__ == '__main__':
    # Minimal demo: build a three-node graph and render it to an image file.
    labels = ['v1', 'v2', 'v3']
    matrix = np.array([
        [0, 1, 1],
        [1, 0, 0],
        [0, 0, 1],
    ])
    demo = Graph(matrix, labels)
    # Pass path=None instead to open an interactive window rather than saving.
    demo.show(path='mygraph.png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment