GloVe: fast, high-performance distributed word representations (a PyTorch implementation) ref: https://qiita.com/GushiSnow/items/e92ac2fea4f8448491ba
J(\theta) = \frac{1}{2}\sum^{W}_{i,j=1} f(P_{ij}) \left(u^{T}_{i} v_{j} - \log{P_{ij}}\right)^2
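Here v_j is the center word vector, u_i the target (context) word vector, and P_{ij} the co-occurrence count of words i and j; the code further below also adds per-word bias terms inside the square, as in the original GloVe paper. f is a weighting function that damps rare pairs and caps very frequent ones. With x_max = 100 and alpha = 0.75, the values used in the snippet that follows:

f(x) = \begin{cases} (x / x_{max})^{\alpha} & \text{if } x < x_{max} \\ 1 & \text{otherwise} \end{cases}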
def __weighting(self, X_ik: dict, w_i: str, w_j: str):
    # Unseen pairs get a count of 1 so that log(X_ij) stays defined.
    try:
        x_ij = X_ik[(w_i, w_j)]
    except KeyError:
        x_ij = 1
    x_max = 100
    alpha = 0.75
    # Damp rare pairs; cap the weight of very frequent pairs at 1.
    if x_ij < x_max:
        result = (x_ij / x_max) ** alpha
    else:
        result = 1
    return result
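To see what the weighting does in practice, here is a minimal sketch of the same logic as a free function (the real code is a private method; the pair counts below are made up):

def weighting(X_ik, w_i, w_j, x_max=100, alpha=0.75):
    x_ij = X_ik.get((w_i, w_j), 1)
    return (x_ij / x_max) ** alpha if x_ij < x_max else 1

X_ik = {('cat', 'sat'): 50, ('the', 'cat'): 500}   # made-up counts
print(weighting(X_ik, 'cat', 'sat'))   # (50/100) ** 0.75 ≈ 0.5946
print(weighting(X_ik, 'the', 'cat'))   # 500 >= x_max, so capped at 1
print(weighting(X_ik, 'dog', 'sat'))   # unseen pair counts as 1 -> ≈ 0.0316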
def forward(self, center_words, target_words, coocs, weights):
    center_embeds = self.embedding_v(center_words)
    target_embeds = self.embedding_u(target_words)
    # Reference (squeeze)
    # http://pytorch.org/docs/master/torch.html#torch.squeeze
    center_bias = self.v_bias(center_words).squeeze(1)
    target_bias = self.u_bias(target_words).squeeze(1)
    # u_i^T v_j for every pair in the batch
    inner_product = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2)  # noqa
    # Weighted squared error against the log co-occurrence counts (the GloVe loss)
    loss = weights * torch.pow(inner_product + center_bias + target_bias - coocs, 2)  # noqa
    return torch.sum(loss)
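The four lookup tables referenced in forward are not defined in this excerpt. A plausible constructor, assuming the layer names used above (the class name and the vocab_size / projection_dim parameters are hypothetical):

import torch
import torch.nn as nn

class GloVe(nn.Module):  # class name is an assumption
    def __init__(self, vocab_size, projection_dim):
        super(GloVe, self).__init__()
        # Center and target word vectors (v and u in the loss above)
        self.embedding_v = nn.Embedding(vocab_size, projection_dim)
        self.embedding_u = nn.Embedding(vocab_size, projection_dim)
        # Per-word scalar biases, stored as 1-dimensional embeddings
        self.v_bias = nn.Embedding(vocab_size, 1)
        self.u_bias = nn.Embedding(vocab_size, 1)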
x
-0.9697  0.1701 -0.5611
 0.0019 -0.1810  0.1066
[torch.FloatTensor of size 2x3]
torch.cat(x)
-0.9697
 0.1701
-0.5611
 0.0019
-0.1810
 0.1066
[torch.FloatTensor of size 6]
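The training loop below uses torch.cat in exactly this way: each batch arrives as tuples of per-example tensors, and concatenation merges them into a single tensor per argument before the forward pass.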
for epoch in range(self.epoch):
    for i, batch in enumerate(get_batch(batch_size=self.batch_size,
                                        train_data=train_data)):
        # Each batch is a list of (input, target, cooc, weight) tuples;
        # zip(*batch) groups them and torch.cat merges each group.
        inputs, targets, coocs, weights = zip(*batch)
        inputs = torch.cat(inputs)
        targets = torch.cat(targets)
        coocs = torch.cat(coocs)
        weights = torch.cat(weights)
        self.model.zero_grad()
        loss = self.model(inputs, targets, coocs, weights)
        loss.backward()
        self.optimizer.step()
        losses.append(loss.data.tolist()[0])
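get_batch is not shown in the excerpt; a minimal sketch consistent with the loop above (assuming train_data is a list of (input, target, cooc, weight) tuples):

import random

def get_batch(batch_size, train_data):
    # Hypothetical helper: shuffle once per epoch, then yield fixed-size chunks.
    random.shuffle(train_data)
    for i in range(0, len(train_data), batch_size):
        yield train_data[i:i + batch_size]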
import codecs
from operator import methodcaller

def read_file(file_name: str):
    with codecs.open(file_name, 'r', encoding='utf-8', errors='ignore') as f:
        read_data = f.read().split('\n')
        # Split each line on spaces to get a list of token lists.
        read_data = list(map(methodcaller("split", " "), read_data))
    return read_data
[['"', 'I', 'thought', 'so', '.'], ['All', 'right', ';', 'take', 'a', 'seat', '.'], ['Supper', '?--', 'you', 'want', 'supper', '?'], ['Supper', "'", 'll', 'be', 'ready', 'directly', '."']] |
{'': 0, '</s>': 1, '、': 2, '。': 3, 'が': 4}
{0: '', 1: '</s>', 2: '、', 3: '。', 4: 'が'}
def __make_word2index(self, vocab: list=[]):
    word2index = {}
    for vo in vocab:
        if vo not in word2index:
            word2index[vo] = len(word2index)
    index2word = {v: k for k, v in word2index.items()}
    # Keep both mappings sorted by index for readability.
    word2index = dict(collections.OrderedDict(sorted(word2index.items(),
                                                     key=lambda t: t[1])))
    index2word = dict(collections.OrderedDict(sorted(index2word.items(),
                                                     key=lambda t: t[0])))
    return word2index, index2word
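For illustration, the same logic as a standalone snippet on a toy vocabulary (the output format matches the dictionaries shown above):

vocab = ['</s>', '、', '。', 'が', '、']   # toy vocabulary with one duplicate
word2index = {}
for vo in vocab:
    if vo not in word2index:
        word2index[vo] = len(word2index)
index2word = {v: k for k, v in word2index.items()}
print(word2index)  # {'</s>': 0, '、': 1, '。': 2, 'が': 3}
print(index2word)  # {0: '</s>', 1: '、', 2: '。', 3: 'が'}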
def __make_window_data(self, window_size: int=5,
                       corpus: list=[]):
    # Pad each sentence with dummy tokens so edge words get full windows.
    windows = flatten([list(nltk.ngrams(['<DUMMY>'] * window_size + c +
                                        ['<DUMMY>'] * window_size,
                                        window_size*2+1)) for c in corpus])
    window_data = []
    for window in windows:
        for i in range(window_size*2 + 1):
            if i == window_size or window[i] == '<DUMMY>':
                continue
            # Pair the center word with each real neighbour.
            window_data.append((window[window_size], window[i]))
    return window_data
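flatten is not defined in the excerpt; a common definition (an assumption here), plus a toy run of the windowing logic:

import nltk

# Assumed definition of the flatten helper used above.
flatten = lambda l: [item for sublist in l for item in sublist]

corpus = [['i', 'like', 'deep', 'learning']]
window_size = 2
windows = flatten([list(nltk.ngrams(['<DUMMY>'] * window_size + c +
                                    ['<DUMMY>'] * window_size,
                                    window_size * 2 + 1)) for c in corpus])
# The first window is ('<DUMMY>', '<DUMMY>', 'i', 'like', 'deep'):
# the center word 'i' is paired with 'like' and 'deep'; dummies are skipped.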
>>> import itertools
>>> A = ['a', 'b', 'c']
>>> list(itertools.combinations_with_replacement(A, 3))
[('a', 'a', 'a'),
 ('a', 'a', 'b'),
 ('a', 'a', 'c'),
 ('a', 'b', 'b'),
 ('a', 'b', 'c'),
 ('a', 'c', 'c'),
 ('b', 'b', 'b'),
 ('b', 'b', 'c'),
 ('b', 'c', 'c'),
 ('c', 'c', 'c')]
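The co-occurrence builder below relies on this behaviour: each unordered vocabulary pair is generated exactly once, and the count and weight are then mirrored to the reversed (j, i) pair.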
def __make_co_occurence_matrix(self,
                               window_data: list=[],
                               vocab: list=[]):
    X_ik_window_5 = Counter(window_data)
    X_ik = {}
    weighting_dict = {}
    # Enumerate each unordered vocabulary pair once, then mirror the
    # co-occurrence count and weight to the reversed pair.
    for bigram in combinations_with_replacement(vocab, 2):
        if bigram in X_ik_window_5:
            co_occer = X_ik_window_5[bigram]
            X_ik[bigram] = co_occer + 1
            X_ik[bigram[1], bigram[0]] = co_occer + 1
        else:
            pass
        weighting_dict[bigram] = self.__weighting(X_ik=X_ik,
                                                  w_i=bigram[0],
                                                  w_j=bigram[1])
        weighting_dict[bigram[1], bigram[0]] = \
            self.__weighting(X_ik=X_ik, w_i=bigram[1], w_j=bigram[0])
    weighting_dict = dict(collections.OrderedDict(
        sorted(weighting_dict.items(), key=lambda t: t[1])))
    return X_ik, weighting_dict
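A toy run of the counting part (made-up window pairs; the weighting step is omitted):

from collections import Counter
from itertools import combinations_with_replacement

window_data = [('i', 'like'), ('like', 'i'), ('like', 'deep')]  # made-up pairs
vocab = ['i', 'like', 'deep']
counts = Counter(window_data)
X_ik = {}
for bigram in combinations_with_replacement(vocab, 2):
    if bigram in counts:
        # +1 smoothing as in the method above; the count is mirrored
        # to the reversed pair so the matrix stays symmetric.
        X_ik[bigram] = counts[bigram] + 1
        X_ik[(bigram[1], bigram[0])] = counts[bigram] + 1
print(X_ik)
# {('i', 'like'): 2, ('like', 'i'): 2, ('like', 'deep'): 2, ('deep', 'like'): 2}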