Beam Search
import torch
import tqdm
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load GPT-2 in bfloat16 and use the EOS token for padding (GPT-2 has no pad token)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2').to(torch.bfloat16)
assert next(model.parameters()).dtype == torch.bfloat16

input_text = 'Earth is'
MAX_DEPTH = 10
BEAM_WIDTH = 3

# Every beam starts from the same prompt with log-probability zero
text_cands = [input_text] * BEAM_WIDTH
logprob_cands = torch.zeros(BEAM_WIDTH)

# First expansion: score the prompt once and seed the beams with the top-k next tokens
with torch.no_grad():
    encoded_input = tokenizer(input_text, return_tensors='pt')
    logprobs = model(**encoded_input).logits.log_softmax(dim=-1)[0, -1, :]

    topk_logprobs, topk_idxs = torch.topk(logprobs, k=BEAM_WIDTH)
    topk_tokens = tokenizer.batch_decode(topk_idxs)
    text_cands = [txt + tok for txt, tok in zip(text_cands, topk_tokens)]
    logprob_cands += topk_logprobs
for d in tqdm.trange(MAX_DEPTH - 1):
    with torch.no_grad():
        encoded_inputs = tokenizer(text_cands, return_tensors='pt', padding=True)
        all_logprobs = model(**encoded_inputs).logits.log_softmax(dim=-1)

        # Take the logits at the last *non-pad* position of each beam; with right
        # padding, position -1 can be a pad token when beams re-tokenize to
        # different lengths.
        last_idxs = encoded_inputs['attention_mask'].sum(dim=-1) - 1
        # (beam_width, n_vocabs)
        logprobs = all_logprobs[torch.arange(BEAM_WIDTH), last_idxs, :]

        # (beam_width, n_vocabs)
        total_logprobs = logprob_cands.unsqueeze(1).expand(BEAM_WIDTH, logprobs.shape[-1]) + logprobs

        # Each (beam_width,)
        topk_total_logprobs, topk_idxs_flat = torch.topk(total_logprobs.flatten(), k=BEAM_WIDTH)

        # Convert flattened indices to the real token indices
        topk_idxs = topk_idxs_flat % len(tokenizer)
        topk_tokens = tokenizer.batch_decode(topk_idxs)

        # Infer which text_cand each topk_token is the continuation of
        beam_idxs = topk_idxs_flat // len(tokenizer)

        # Append topk_tokens to the respective sequences
        text_cands = [text_cands[idx] + tok for idx, tok in zip(beam_idxs, topk_tokens)]
        logprob_cands = topk_total_logprobs
print()
print(f'Context: {input_text}')
print('Results:')

for score, cand in zip(logprob_cands, text_cands):
    print(f'{score:.2f}: {cand}')
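
The flattened top-k bookkeeping in the loop (`%` recovers the token index, `//` the beam index) can be checked in isolation. A small standalone sketch with made-up numbers, not part of the gist:

import torch

scores = torch.tensor([[0.1, 0.9, 0.2],   # beam 0 over a toy 3-token vocabulary
                       [0.8, 0.3, 0.7]])  # beam 1
vals, flat_idxs = torch.topk(scores.flatten(), k=2)   # vals = [0.9, 0.8], flat_idxs = [1, 3]
beam_idxs = flat_idxs // scores.shape[-1]   # tensor([0, 1]): which beam each hit came from
token_idxs = flat_idxs % scores.shape[-1]   # tensor([1, 0]): which token within that beam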
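
For comparison, Hugging Face's built-in beam search can decode from the same model. This is a minimal sketch assuming the `model`, `tokenizer`, `input_text`, `BEAM_WIDTH`, and `MAX_DEPTH` objects defined above; `generate` applies length normalization by default, so its ranking may differ slightly from the manual loop.

encoded = tokenizer(input_text, return_tensors='pt')
with torch.no_grad():
    out = model.generate(
        **encoded,
        num_beams=BEAM_WIDTH,
        num_return_sequences=BEAM_WIDTH,
        max_new_tokens=MAX_DEPTH,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
for seq in out:
    print(tokenizer.decode(seq))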