Last active
May 27, 2024 09:25
-
-
Save ShikiOkasaka/c87bdf5bcb996658f579b9a8bb23a6bb to your computer and use it in GitHub Desktop.
Hugging Face Transformersでかな漢字変換の実験
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# pip install transformers | |
# pip install fugashi | |
# pip install ipadic | |
# pip install unidic_lite | |
import torch | |
from transformers import BertForMaskedLM | |
from transformers import BertJapaneseTokenizer | |
# The tokenizer and the masked-LM weights are loaded from the same checkpoint.
_MODEL_NAME = 'cl-tohoku/bert-base-japanese-v3'
tokenizer = BertJapaneseTokenizer.from_pretrained(_MODEL_NAME)
model = BertForMaskedLM.from_pretrained(_MODEL_NAME)
def pick(candidates):
    """Return the candidate whose first distinguishing token the model prefers.

    All candidates are tokenized together and the first token position at
    which they disagree is located. The tokens shared before that position,
    followed by the tokenizer's mask token and separator token, are fed to
    the masked language model. Each candidate is then scored by the logit the
    model assigns, at the masked position, to that candidate's own token.

    The query, the sorted scores and the matching candidate indices are
    printed as a side effect.

    Args:
        candidates: Non-empty sequence (tuple or list) of candidate strings.

    Returns:
        The element of ``candidates`` with the highest score. When the
        tokenized candidates never differ (a single candidate, or identical
        candidates) the first candidate is returned without running the
        model, because there is nothing to rank.

    Raises:
        ValueError: If ``candidates`` is empty.
    """
    if not candidates:
        raise ValueError('candidates must not be empty')
    print('Q: ', candidates)
    encoded = tokenizer(candidates)
    # Column i holds the i-th token id of every candidate. zip() stops at the
    # shortest candidate, whose final token is the one appended by the
    # tokenizer, so a shorter candidate still differs from a longer one there.
    columns = list(zip(*encoded.input_ids))
    mask_token_index = next(
        (i for i, column in enumerate(columns) if len(set(column)) != 1),
        None,
    )
    if mask_token_index is None:
        # No position distinguishes the candidates; scoring would compare a
        # token against itself, so skip the model entirely.
        return candidates[0]
    shared_prefix = encoded.input_ids[0][:mask_token_index]
    input_ids = [*shared_prefix, tokenizer.mask_token_id, tokenizer.sep_token_id]
    # Inference only: no autograd graph is needed to read the logits.
    with torch.no_grad():
        logits = model(input_ids=torch.tensor(input_ids).unsqueeze(0)).logits
    # The token each candidate places at the masked position.
    token_ids = list(columns[mask_token_index])
    topk = torch.topk(logits[0, mask_token_index][token_ids], k=len(candidates))
    print(' ', topk.values.tolist())
    print(' ', topk.indices.tolist())
    # indices[0] is the position, within candidates, of the best-scoring one.
    return candidates[int(topk.indices[0])]
# Demo run: each tuple groups alternative spellings to choose between;
# the winner reported by pick() is printed after its own diagnostic output.
for candidates in (
    ('わたしの生き概論', 'わたしの生きが異論', 'わたしの生きがい論'),
    ('電車に乗って', '電車に載って'),
    ('新聞に乗って', '新聞に載って'),
    ('先生にあって間隙', '先生にあって観劇', '先生にあって感激'),
):
    print('A: ', pick(candidates))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment