Created May 6, 2019 14:50
Jensen Shannon distance
import numpy as np
from scipy.stats import entropy


def jensen_shannon(query, matrix):
    """
    This function implements the Jensen-Shannon distance
    between the input query (an LDA topic distribution for a document)
    and the entire corpus of topic distributions.
    It returns an array of length M, where M is the number of documents in the corpus.
    """
    # Keep the usual p, q notation: p is the query, q is the corpus
    p = query[None, :].T  # column vector of shape (K, 1), K = number of topics
    q = matrix.T  # shape (K, M), one column per document
    m = 0.5 * (p + q)  # mixture distribution, broadcast to shape (K, M)
    return np.sqrt(0.5 * (entropy(p, m) + entropy(q, m)))


def get_most_similar_documents(query, matrix, k=10):
    """
    This function uses the Jensen-Shannon distance above
    and returns the indices of the k smallest Jensen-Shannon distances.
    """
    sims = jensen_shannon(query, matrix)  # array of Jensen-Shannon distances
    return sims.argsort()[:k]  # positional indices of the k smallest distances
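As a sanity check on the ranking logic, here is a minimal sketch that computes the same kind of distances with `scipy.spatial.distance.jensenshannon`, which is available in recent SciPy releases. The toy `corpus` and `query` arrays are hypothetical values made up for illustration, not data from the gist, and the per-row loop is an assumption chosen for readability rather than speed.

```python
import numpy as np
from scipy.spatial.distance import jensenshannon

# Hypothetical toy corpus: 4 documents over 3 topics, each row sums to 1.
corpus = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.2, 0.6],
])
query = np.array([0.7, 0.2, 0.1])  # identical to document 0

# Jensen-Shannon distance (natural log) from the query to every document.
distances = np.array([jensenshannon(query, doc) for doc in corpus])

# Indices of the k closest documents, smallest distance first.
k = 2
top_k = distances.argsort()[:k]
print(top_k)
```

With this toy data, document 0 comes first because it matches the query exactly (distance 0), followed by document 2, whose topic mix is closest to the query. With the natural logarithm, every distance lies between 0 and sqrt(ln 2), roughly 0.83.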