Vocabulary for Deep NLP
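A self-contained `Vocabulary` class for deep NLP pipelines: it counts token frequencies, builds word-to-index and index-to-word mappings (with `<pad>`, `<unk>`, `<bos>`, `<eos>`, and any reserved tokens registered at the lowest indices), and serializes its full state to and from JSON. A minimal usage sketch, with illustrative token lists (exact indices depend on frequency and insertion order):

    vocab = Vocabulary()
    vocab.fit([['the', 'cat', 'sat'], ['the', 'dog', 'ran']])
    vocab.to_indices(['the', 'cat'])   # e.g. [4, 5]; indices 0-3 hold the special tokens
    vocab.to_tokens([4, 5])            # ['the', 'cat']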
import copy
import json
from typing import List
from collections import Counter
from pathlib import Path
class Vocabulary(object):
    """Token/index mapping for NLP models, with special tokens and JSON round-tripping."""

    def __init__(self,
                 max_size=None,
                 min_freq=1,
                 unknown_token='<unk>',
                 padding_token='<pad>',
                 bos_token='<bos>',
                 eos_token='<eos>',
                 reserved_tokens=None):
        self.max_size = max_size
        self.min_freq = min_freq
        self.vocab_size = 0
        self.unknown_token = unknown_token
        self.padding_token = padding_token
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.reserved_tokens = reserved_tokens
        self.word_frequency = None
        self._word_to_idx = None
        self._idx_to_word = None
    def fit(self, tokenized_dataset: List) -> None:
        """Build the word/index mappings from a token list or a list of token lists."""
        if isinstance(tokenized_dataset[0], list):
            linear_dataset = self._square_to_linear(tokenized_dataset)
        else:
            linear_dataset = tokenized_dataset

        # Accumulate counts so fit() can be called incrementally on new data.
        if self.word_frequency is None:
            self.word_frequency = Counter(linear_dataset)
        else:
            self.word_frequency.update(linear_dataset)

        filtered_word_frequency = self._filter_min_freq(self.word_frequency, self.min_freq)
        max_size = self.max_size
        if max_size is None or len(filtered_word_frequency) < max_size:
            max_size = len(filtered_word_frequency)
        most_common_word_freq = filtered_word_frequency.most_common(max_size)

        self._create_word_dict()
        for word, _ in most_common_word_freq:
            self._idx_to_word.append(word)
            self._word_to_idx[word] = len(self._idx_to_word) - 1
        self.vocab_size = len(self._idx_to_word)  # keep the size attribute in sync
    def to_indices(self, tokens: List):
        return self[tokens]
    def to_tokens(self, indices: List):
        to_reduce = False
        if not isinstance(indices, (list, tuple)):
            indices = [indices]
            to_reduce = True

        max_idx = len(self._idx_to_word) - 1
        tokens = []
        for idx in indices:
            # Reject non-integers and out-of-range values (including negatives,
            # which would otherwise silently index from the end of the list).
            if not isinstance(idx, int) or not 0 <= idx <= max_idx:
                raise ValueError('Token index {} in the provided `indices` is invalid.'.format(idx))
            tokens.append(self._idx_to_word[idx])
        return tokens[0] if to_reduce else tokens
    def to_json(self, json_path: Path) -> None:
        vocab_obj = dict()
        vocab_obj['max_size'] = self.max_size
        vocab_obj['min_freq'] = self.min_freq
        vocab_obj['vocab_size'] = self.vocab_size
        vocab_obj['unknown_token'] = self.unknown_token
        vocab_obj['padding_token'] = self.padding_token
        vocab_obj['bos_token'] = self.bos_token
        vocab_obj['eos_token'] = self.eos_token
        vocab_obj['reserved_tokens'] = self.reserved_tokens
        vocab_obj['word_frequency'] = dict(self.word_frequency)
        vocab_obj['word_to_idx'] = self._word_to_idx
        vocab_obj['idx_to_word'] = self._idx_to_word
        # ensure_ascii=False keeps non-ASCII tokens readable in the saved file.
        with open(json_path, 'w', encoding='utf-8') as jsonfile:
            json.dump(vocab_obj, jsonfile, indent=4, ensure_ascii=False)
    def from_json(self, json_path: Path) -> None:
        with open(json_path, 'r', encoding='utf-8') as jsonfile:
            vocab_obj = json.load(jsonfile)
        self.max_size = vocab_obj['max_size']
        self.min_freq = vocab_obj['min_freq']
        self.vocab_size = vocab_obj['vocab_size']
        self.unknown_token = vocab_obj['unknown_token']
        self.padding_token = vocab_obj['padding_token']
        self.bos_token = vocab_obj['bos_token']
        self.eos_token = vocab_obj['eos_token']
        self.reserved_tokens = vocab_obj['reserved_tokens']
        self.word_frequency = Counter(vocab_obj['word_frequency'])
        self._word_to_idx = vocab_obj['word_to_idx']
        self._idx_to_word = vocab_obj['idx_to_word']
    @property
    def word_to_idx(self):
        return self._word_to_idx

    @property
    def idx_to_word(self):
        return self._idx_to_word

    def __len__(self):
        return len(self._idx_to_word)
    def __getitem__(self, words):
        # Unknown words map to the index of the unknown token.
        unk_idx = self._word_to_idx[self.unknown_token]
        if not isinstance(words, (list, tuple)):
            return self._word_to_idx.get(words, unk_idx)
        return [self._word_to_idx.get(w, unk_idx) for w in words]
    def __eq__(self, other):
        return (self.max_size == other.max_size
                and self.min_freq == other.min_freq
                and self.vocab_size == other.vocab_size
                and self.unknown_token == other.unknown_token
                and self.padding_token == other.padding_token
                and self.bos_token == other.bos_token
                and self.eos_token == other.eos_token
                and self.reserved_tokens == other.reserved_tokens
                and self.word_frequency == other.word_frequency
                and self._word_to_idx == other.word_to_idx
                and self._idx_to_word == other.idx_to_word)
    def _create_word_dict(self) -> None:
        # Special tokens occupy the lowest indices, in a fixed order:
        # padding, unknown, bos, eos, then any reserved tokens.
        self._word_to_idx = dict()
        self._idx_to_word = list()
        special_tokens = [self.padding_token, self.unknown_token,
                          self.bos_token, self.eos_token]
        if self.reserved_tokens is not None:
            special_tokens.extend(self.reserved_tokens)
        for token in special_tokens:
            if token is not None:
                self._idx_to_word.append(token)
                self._word_to_idx[token] = len(self._idx_to_word) - 1
    @classmethod
    def _filter_min_freq(cls, word_frequency: Counter, min_freq: int) -> Counter:
        # Work on a copy so the full frequency table on the instance is preserved.
        filtered_word_frequency = copy.deepcopy(word_frequency)
        for word, freq in list(filtered_word_frequency.items()):
            if freq < min_freq:
                del filtered_word_frequency[word]
        return filtered_word_frequency
    @classmethod
    def _square_to_linear(cls, squared_list: List) -> List:
        # Flatten a list of token lists into one flat token list.
        return [word for sequence in squared_list for word in sequence]
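
# fit() may be called repeatedly: frequencies accumulate and the index is
# rebuilt from the combined counts. A small illustrative sketch:
#
#     vocab = Vocabulary(min_freq=2)
#     vocab.fit(['a', 'b', 'a'])   # only 'a' reaches min_freq, plus special tokens
#     vocab.fit(['b'])             # 'b' now has count 2 and enters the vocabulary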
## Test cases

def test_create_vocabulary():
    dummy_inputs = [['나는', '한국에', '살고', '있어요'],
                    ['한국에', '사는건', '쉽지', '않아요'],
                    ['학교종이', '울리면', '모여야', '해요'],
                    ['학교종이', '울리지', '않으면', '어디로', '가야', '하죠']]
    # 16 unique words plus the 4 default special tokens (<pad>, <unk>, <bos>, <eos>).
    answer_vocab_size = 20

    vocab = Vocabulary()
    vocab.fit(dummy_inputs)
    sampled_tokens = dummy_inputs[0]
    indices = vocab.to_indices(sampled_tokens)
    reversed_tokens = vocab.to_tokens(indices)

    assert isinstance(indices, list)
    assert isinstance(indices[0], int)
    assert len(vocab) == answer_vocab_size
    assert reversed_tokens == sampled_tokens
def test_unknown_token_to_index():
    dummy_inputs = [['나는', '한국에', '살고', '있어요'],
                    ['한국에', '사는건', '쉽지', '않아요'],
                    ['학교종이', '울리면', '모여야', '해요'],
                    ['학교종이', '울리지', '않으면', '어디로', '가야', '하죠']]
    # None of these tokens appear in dummy_inputs, so each maps to <unk>.
    unknown_token = '기러기'
    unknown_tokens = ['한쿡', '기러기', '정말', '많다']
    # <unk> sits at index 1 because _create_word_dict registers <pad> first.
    unknown_index = 1
    unknown_indices = [1, 1, 1, 1]

    vocab = Vocabulary()
    vocab.fit(dummy_inputs)

    assert unknown_index == vocab.to_indices(unknown_token)
    assert unknown_indices == vocab.to_indices(unknown_tokens)
def test_vocab_obj_as_json():
    json_path = Path('data/test/test_dataset/vocab_test.json')
    # Create the target directory so the test also runs from a fresh checkout.
    json_path.parent.mkdir(parents=True, exist_ok=True)
    dummy_inputs = [['나는', '한국에', '살고', '있어요'],
                    ['한국에', '사는건', '쉽지', '않아요'],
                    ['학교종이', '울리면', '모여야', '해요'],
                    ['학교종이', '울리지', '않으면', '어디로', '가야', '하죠']]

    vocab, dummy_vocab = Vocabulary(), Vocabulary()
    vocab.fit(dummy_inputs)
    vocab.to_json(json_path)
    dummy_vocab.from_json(json_path)

    assert vocab == dummy_vocab
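
# Minimal runner so the tests execute as a plain script (pytest would also
# collect them automatically; this block is just a convenience sketch).
if __name__ == '__main__':
    test_create_vocabulary()
    test_unknown_token_to_index()
    test_vocab_obj_as_json()
    print('All vocabulary tests passed.')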