@wiseodd
Created March 31, 2024 17:57
Beam Search
import torch
import tqdm
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2').to(torch.bfloat16)
assert next(model.parameters()).dtype == torch.bfloat16
input_text = 'Earth is'
MAX_DEPTH = 10
BEAM_WIDTH = 3
text_cands = [input_text] * BEAM_WIDTH
logprob_cands = torch.zeros(BEAM_WIDTH)
# First step: expand the single prompt into BEAM_WIDTH initial candidates.
with torch.no_grad():
    encoded_input = tokenizer(input_text, return_tensors='pt')
    logprobs = model(**encoded_input).logits.log_softmax(dim=-1)[0, -1, :]

topk_logprobs, topk_idxs = torch.topk(logprobs, k=BEAM_WIDTH)
topk_tokens = tokenizer.batch_decode(topk_idxs)
text_cands = [txt + tok for txt, tok in zip(text_cands, topk_tokens)]
logprob_cands += topk_logprobs
# Subsequent steps: score all BEAM_WIDTH * n_vocab one-token extensions
# and keep the BEAM_WIDTH best ones.
for d in tqdm.trange(MAX_DEPTH - 1):
    with torch.no_grad():
        encoded_inputs = tokenizer(text_cands, return_tensors='pt', padding=True)
        # (beam_width, seq_len, n_vocab)
        all_logprobs = model(**encoded_inputs).logits.log_softmax(dim=-1)
        # Take each beam's last *non-pad* position; with right padding,
        # [:, -1, :] would read the logits at a pad token whenever the
        # beams tokenize to different lengths.
        last_idxs = encoded_inputs['attention_mask'].sum(dim=-1) - 1
        # (beam_width, n_vocab)
        logprobs = all_logprobs[torch.arange(BEAM_WIDTH), last_idxs, :]

    # (beam_width, n_vocab): cumulative log-prob of every one-token extension
    total_logprobs = logprob_cands.unsqueeze(1).expand(BEAM_WIDTH, logprobs.shape[-1]) + logprobs

    # Each (beam_width,)
    topk_total_logprobs, topk_idxs_flat = torch.topk(total_logprobs.flatten(), k=BEAM_WIDTH)

    # Convert flattened indices back to vocabulary indices
    topk_idxs = topk_idxs_flat % len(tokenizer)
    topk_tokens = tokenizer.batch_decode(topk_idxs)

    # Infer which text_cand each topk_token is the continuation of
    beam_idxs = topk_idxs_flat // len(tokenizer)

    # Append topk_tokens to the respective sequences
    text_cands = [text_cands[idx] + tok for idx, tok in zip(beam_idxs, topk_tokens)]
    logprob_cands = topk_total_logprobs
print()
print(f'Context: {input_text}')
print('Results:')
for score, cand in zip(logprob_cands, text_cands):
    print(f'{score:.2f}: {cand}')
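
For comparison, Hugging Face ships its own beam-search decoder via model.generate. The sketch below mirrors BEAM_WIDTH and MAX_DEPTH from the script above; its results can differ slightly from the manual loop because generate applies its own defaults (e.g. length-normalized beam scores).

# Optional cross-check with the built-in beam search; this mirrors the
# constants above and is not part of the manual implementation.
encoded = tokenizer(input_text, return_tensors='pt')
with torch.no_grad():
    beam_outputs = model.generate(
        **encoded,
        max_new_tokens=MAX_DEPTH,
        num_beams=BEAM_WIDTH,
        num_return_sequences=BEAM_WIDTH,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
for seq in tokenizer.batch_decode(beam_outputs, skip_special_tokens=True):
    print(seq)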