Last active
November 2, 2017 21:45
-
-
Save norouzi/8c4d244922fa052fa8ec18d8af52d366 to your computer and use it in GitHub Desktop.
Reward Augmented Maximum Likelihood (RAML; https://arxiv.org/pdf/1609.00150.pdf) -- Python code snippet that computes the marginal distribution over the number of edits for a given sequence length, temperature, and vocabulary size.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import scipy.misc as misc
import scipy.special as special
# Marginal distribution over the number of edits used by Reward Augmented
# Maximum Likelihood (RAML; https://arxiv.org/pdf/1609.00150.pdf).
#
# After this script runs:
#   x[e] is the unnormalized log-weight of making exactly e edits,
#   p[e] is the corresponding normalized probability,
# for e in 0 .. max_edits - 1.
len_target = 20  # Target sequence length
v = 60  # Vocabulary size
T = .9  # Temperature
max_edits = len_target
x = np.zeros(max_edits)
log_v = np.log(v)  # loop-invariant, computed once
for n_edits in range(max_edits):
    total_n_edits = 0  # total edits with n_edits edits without v^n_edits term
    for n_substitutes in range(min(len_target, n_edits) + 1):
        # Trace output kept from the original script; the call form prints
        # the same text on Python 2 and Python 3.
        print(n_substitutes)
        n_insert = n_edits - n_substitutes
        # Ways to pick the substituted positions times ways to place the
        # insertions among the remaining slots.
        current_edits = special.comb(len_target, n_substitutes, exact=False) * \
            special.comb(len_target + n_insert - n_substitutes, n_insert,
                         exact=False)
        total_n_edits += current_edits
    x[n_edits] = np.log(total_n_edits) + n_edits * log_v
    # log(tot_edits * v^n_edits)
    x[n_edits] = x[n_edits] - n_edits / T * log_v - n_edits / T
    # log(tot_edits * v^n_edits * exp(-n_edits / T) * v^(-n_edits / T))
# Normalize in a numerically stable way: subtracting the maximum before
# exponentiating leaves the normalized result unchanged but avoids overflow
# in np.exp for larger lengths or temperatures.
p = np.exp(x - np.max(x))
p /= np.sum(p)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment