Created
December 20, 2022 03:07
-
-
Save fzyzcjy/fab4bf82c62f23b3432123c84f14a2c6 to your computer and use it in GitHub Desktop.
Speed up HuggingFace beam search by 10x
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings | |
from collections import UserDict, defaultdict | |
from typing import Optional, Tuple, Any | |
import torch | |
from transformers import BeamScorer, BeamSearchScorer | |
from transformers.generation import BeamHypotheses | |
from ...utils.torch_utils import first_several_nonzero_indices | |
class MyBeamSearchScorer(BeamScorer):
    """Vectorized drop-in replacement for ``transformers.BeamSearchScorer``.

    ``process`` selects the next-step beams with batched tensor operations. Only
    the candidates that are EOS tokens of still-running sentences are copied to
    the CPU and pushed, one by one, into the per-sentence ``BeamHypotheses``.
    """

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs,
    ):
        # Validate before deriving `group_size`, so bad arguments raise ValueError
        # instead of e.g. ZeroDivisionError.
        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
                " one should make use of `greedy_search` instead."
            )
        if (
            not isinstance(num_beam_groups, int)
            or num_beam_groups < 1
            or (num_beam_groups > num_beams)
            or (num_beams % num_beam_groups != 0)
        ):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )
        if "max_length" in kwargs:
            warnings.warn(
                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
                ", or `group_beam_search(...)`."
            )

        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups
        self._is_init = False
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        # One flag per sentence; set once its hypotheses are reported done.
        self._done: torch.Tensor = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
        # Profiling accumulator (phase name -> seconds); kept for interface compatibility.
        self.t_dict = defaultdict(lambda: 0.0)

    @property
    def is_done(self) -> bool:
        """True once every sentence in the batch is finished."""
        return bool(self._done.all())

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor]:
        """Select the beams to continue with and record finished hypotheses.

        Args:
            input_ids: running sequences, ``group_size`` rows per sentence; the
                last dimension is the current length.
            next_scores, next_tokens, next_indices: per-sentence candidates (one
                row per sentence). The column index is treated as the candidate's
                rank, and ``next_indices`` is the beam within the sentence that
                the candidate extends.
            pad_token_id: written into the token slots of finished sentences, if given.
            eos_token_id: token that finishes a hypothesis; ``None`` disables EOS handling.
            beam_indices: optional per-row beam-index history, extended for
                hypotheses that finish in this step.

        Returns:
            ``UserDict`` with flattened ``next_beam_scores``, ``next_beam_tokens``
            and ``next_beam_indices`` of length ``batch_size * group_size``.

        Raises:
            ValueError: if ``input_ids`` does not hold ``group_size`` rows per sentence.
        """
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        if batch_size != input_ids.shape[0] // self.group_size:
            if self.num_beam_groups > 1:
                raise ValueError(
                    f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
                    f"size of {self.group_size} is expected by the beam scorer."
                )
            raise ValueError(
                f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
                f"{self.group_size} is expected by the beam scorer."
            )

        device = input_ids.device
        next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
        next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
        next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
        # Row of `input_ids` that each candidate extends.
        batch_beam_indices = torch.arange(batch_size, device=device)[:, None] * self.group_size + next_indices

        # Compare against None explicitly: `eos_token_id` may legitimately be 0.
        if eos_token_id is None:
            is_eos = torch.zeros_like(next_tokens, dtype=torch.bool)
        else:
            is_eos = next_tokens == eos_token_id

        # --- EOS candidates of still-running sentences become finished hypotheses.
        is_eos_and_non_done = (~self._done[:, None]) & is_eos
        # Boolean-mask selection and `nonzero()` both enumerate in row-major
        # order, so the three host arrays below are aligned element by element.
        eos_positions = is_eos_and_non_done.nonzero().cpu().numpy()
        eos_source_beams = next_indices[is_eos_and_non_done].cpu().numpy()
        eos_scores = next_scores[is_eos_and_non_done].cpu().numpy()
        for (batch_idx, beam_token_rank), source_beam, score in zip(eos_positions, eos_source_beams, eos_scores):
            # if beam_token does not belong to top group_size tokens, it should not be added
            if beam_token_rank >= self.group_size:
                continue
            batch_beam_idx = int(batch_idx * self.group_size + source_beam)
            if beam_indices is not None:
                beam_index = beam_indices[batch_beam_idx] + (batch_beam_idx,)
            else:
                beam_index = None
            self._beam_hyps[int(batch_idx)].add(
                input_ids[batch_beam_idx].clone(),
                score.item(),
                beam_indices=beam_index,
            )

        # --- Non-EOS candidates fill the next beams of running sentences.
        # `k` and the reshape width must be `group_size` (the width of the
        # next_beam_* buffers), not `num_beams`, which differs when
        # num_beam_groups > 1. NOTE(review): `first_several_nonzero_indices` is
        # project-local; this assumes it yields `k` entries per row — confirm.
        first_several_non_eos = first_several_nonzero_indices(
            (~is_eos).int(), batch_enable=~self._done, k=self.group_size)
        next_beam_scores[:] = next_scores[first_several_non_eos].reshape((-1, self.group_size))
        next_beam_tokens[:] = next_tokens[first_several_non_eos].reshape((-1, self.group_size))
        next_beam_indices[:] = batch_beam_indices[first_several_non_eos].reshape((-1, self.group_size))

        # Sentences that were already finished only emit padding.
        next_beam_scores[self._done, :] = 0
        if pad_token_id is not None:
            next_beam_tokens[self._done, :] = pad_token_id
        next_beam_indices[self._done, :] = 0

        # Check if we are done so that we can save a pad step if all(done).
        # `.tolist()` yields plain Python floats, like `.item()` would.
        next_scores_max = next_scores.max(dim=1).values.tolist()
        newly_done = [
            beam_hyp.is_done(next_scores_max[batch_idx], cur_len)
            for batch_idx, beam_hyp in enumerate(self._beam_hyps)
        ]
        # Explicit dtype/device: an empty list would otherwise become a float
        # tensor, and `|=` needs both operands on the same device.
        self._done |= torch.tensor(newly_done, dtype=torch.bool, device=self._done.device)

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.LongTensor]:
        """Close all open beams and pack the best hypotheses of every sentence.

        Returns:
            ``UserDict`` with ``sequences`` (``batch_size * num_beam_hyps_to_keep``
            rows, padded with ``pad_token_id`` and terminated by ``eos_token_id``
            when it fits), ``sequence_scores`` and ``beam_indices`` (``None`` when
            no beam indices were tracked).
        """
        batch_size = len(self._beam_hyps)
        # Read the flags once instead of one device sync per sentence.
        done = self._done.tolist()

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if done[batch_idx]:
                continue
            # beam hypothesis class automatically keeps the best beams
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

        # select the best hypotheses
        num_out = batch_size * self.num_beam_hyps_to_keep
        sent_lengths = input_ids.new_empty((num_out,))
        best = []
        best_indices = []
        best_scores = torch.zeros(num_out, device=self.device, dtype=torch.float32)

        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                best_index = best_hyp_tuple[2]
                out_idx = self.num_beam_hyps_to_keep * i + j
                sent_lengths[out_idx] = len(best_hyp)
                best.append(best_hyp)
                best_indices.append(best_index)
                best_scores[out_idx] = best_score

        # prepare for adding eos
        sent_lengths_max = sent_lengths.max().item() + 1
        sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
        decoded: torch.LongTensor = input_ids.new_empty((num_out, sent_max_len))
        if len(best_indices) > 0 and best_indices[0] is not None:
            indices: Optional[torch.LongTensor] = input_ids.new_empty((num_out, sent_max_len))
        else:
            indices = None

        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)
            if indices is not None:
                indices.fill_(-1)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
            decoded[i, : sent_lengths[i]] = hypo
            if indices is not None:
                indices[i, : len(best_idx)] = torch.tensor(best_idx)
            if sent_lengths[i] < sent_max_len:
                decoded[i, sent_lengths[i]] = eos_token_id

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
                "beam_indices": indices,
            }
        )
class BeamSearchScorerForComparison:
    """Debug scorer that runs ``MyBeamSearchScorer`` and ``BeamSearchScorer`` side by side.

    Every ``process`` / ``finalize`` call is forwarded to both implementations
    with identical arguments. Their outputs and ``is_done`` state are asserted
    to match, and the upstream (``theirs``) result is returned.
    """

    def __init__(self, **kwargs: Any):
        self.ours = MyBeamSearchScorer(**kwargs)
        self.theirs = BeamSearchScorer(**kwargs)

    @property
    def _beam_hyps(self):
        assert len(self.ours._beam_hyps) == len(self.theirs._beam_hyps)
        return self.ours._beam_hyps

    @property
    def num_beams(self):
        assert self.ours.num_beams == self.theirs.num_beams
        return self.ours.num_beams

    @property
    def is_done(self):
        assert self.ours.is_done == self.theirs.is_done
        return self.ours.is_done

    def process(self, *args: Any, **kwargs: Any):
        ours_output = self.ours.process(*args, **kwargs)
        theirs_output = self.theirs.process(*args, **kwargs)
        assert isinstance(ours_output, UserDict) and isinstance(theirs_output, UserDict)
        self._check_output_equality(ours_output, theirs_output)
        self._check_state_equality()
        return theirs_output

    def finalize(self, *args: Any, **kwargs: Any):
        ours_output = self.ours.finalize(*args, **kwargs)
        theirs_output = self.theirs.finalize(*args, **kwargs)
        assert isinstance(ours_output, UserDict) and isinstance(theirs_output, UserDict)
        self._check_output_equality(ours_output, theirs_output)
        self._check_state_equality()
        return theirs_output

    @staticmethod
    def _check_output_equality(ours_output: UserDict, theirs_output: UserDict):
        """Assert both outputs have the same keys and equal (allclose) values per key."""
        assert set(ours_output.keys()) == set(theirs_output.keys())
        for k in ours_output.keys():
            assert BeamSearchScorerForComparison._tensor_equals(ours_output[k], theirs_output[k]), \
                f'output not equal. key={k} ' \
                f'ours_output={ours_output[k]} theirs_output={theirs_output[k]} '

    def _check_state_equality(self):
        assert self.ours.is_done == self.theirs.is_done

    @staticmethod
    def _tensor_equals(a: Optional[torch.Tensor], b: Optional[torch.Tensor]) -> bool:
        """Return True if both are None, or both are tensors of equal shape/dtype with close values.

        Returning False, rather than letting ``torch.allclose`` raise on a None
        operand or a dtype mismatch, keeps the caller's assertion message.
        Requiring equal shapes stops ``allclose`` from broadcasting a mismatch
        into a false positive.
        """
        if a is None or b is None:
            return a is None and b is None
        return a.shape == b.shape and a.dtype == b.dtype and torch.allclose(a, b)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
import warnings | |
from collections import defaultdict, UserDict | |
from timeit import default_timer | |
from typing import Optional, Iterable, Union, Callable, List, Tuple | |
import numpy as np | |
import torch | |
import torch.distributed as dist | |
from torch import nn | |
from torch.distributions.constraints import Constraint | |
from transformers import LogitsProcessorList, StoppingCriteriaList, ConstrainedBeamSearchScorer, DisjunctiveConstraint, \ | |
PhrasalConstraint, BeamScorer | |
from transformers.generation import validate_stopping_criteria, BeamHypotheses | |
from transformers.generation.utils import GenerateOutput, GenerationMixin, BeamSearchOutput, \ | |
BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput, SampleOutput, SampleEncoderDecoderOutput, \ | |
SampleDecoderOnlyOutput | |
# copied and modified from: huggingface utils.py | |
from transformers.pytorch_utils import torch_int_div | |
from ...utils.huggingface.beam_search import MyBeamSearchScorer | |
class MyBeamSearchScorerViaNumpy(BeamScorer):
    """NumPy port of ``transformers.BeamSearchScorer``.

    Each call copies its tensor inputs to the CPU once, runs the per-candidate
    bookkeeping loop on NumPy arrays, and converts the results back to tensors
    on the inputs' device. Wall-clock time per phase is accumulated in
    ``self.t_dict``.
    """

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs,
    ):
        # Validate before deriving `group_size`, so bad arguments raise ValueError
        # instead of e.g. ZeroDivisionError.
        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
                " one should make use of `greedy_search` instead."
            )
        if (
            not isinstance(num_beam_groups, int)
            or num_beam_groups < 1
            or (num_beam_groups > num_beams)
            or (num_beams % num_beam_groups != 0)
        ):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )
        if "max_length" in kwargs:
            warnings.warn(
                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
                ", or `group_beam_search(...)`."
            )

        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups
        self._is_init = False
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        # One flag per sentence; set once its hypotheses are reported done.
        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
        # Phase name -> accumulated seconds, filled in by `process`.
        self.t_dict = defaultdict(lambda: 0.0)

    @property
    def is_done(self) -> bool:
        return self._done.all()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor]:
        """Select the beams to continue with and record finished hypotheses.

        Returns:
            ``UserDict`` with flattened ``next_beam_scores``, ``next_beam_tokens``
            and ``next_beam_indices`` tensors on ``input_ids.device``.

        Raises:
            ValueError: on a beam-size mismatch, a finished sentence without
                ``eos_token_id`` / ``pad_token_id``, or too many EOS candidates.
        """
        t0 = default_timer()
        device = input_ids.device
        input_ids = input_ids.cpu().numpy()
        next_scores = next_scores.cpu().numpy()
        next_tokens = next_tokens.cpu().numpy()
        next_indices = next_indices.cpu().numpy()
        # Snapshot the flags on the host: indexing a (possibly GPU) tensor per
        # sentence inside the loop would cost one synchronisation per access.
        done = self._done.tolist()
        self.t_dict['move_to_numpy'] += default_timer() - t0

        t0 = default_timer()
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        if batch_size != input_ids.shape[0] // self.group_size:
            if self.num_beam_groups > 1:
                raise ValueError(
                    f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
                    f"size of {self.group_size} is expected by the beam scorer."
                )
            raise ValueError(
                f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
                f"{self.group_size} is expected by the beam scorer."
            )
        next_beam_scores = np.zeros((batch_size, self.group_size), dtype=next_scores.dtype)
        next_beam_tokens = np.zeros((batch_size, self.group_size), dtype=next_tokens.dtype)
        next_beam_indices = np.zeros((batch_size, self.group_size), dtype=next_indices.dtype)
        self.t_dict['create_array'] += default_timer() - t0

        t0 = default_timer()
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if done[batch_idx]:
                if self.num_beams < len(beam_hyp):
                    raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
                if eos_token_id is None or pad_token_id is None:
                    raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                batch_beam_idx = batch_idx * self.group_size + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    if beam_token_rank >= self.group_size:
                        continue
                    if beam_indices is not None:
                        beam_index = beam_indices[batch_beam_idx] + (batch_beam_idx,)
                    else:
                        beam_index = None
                    beam_hyp.add(
                        input_ids[batch_beam_idx].copy(),
                        next_score.item(),
                        beam_indices=beam_index,
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break

            if beam_idx < self.group_size:
                raise ValueError(
                    f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
                    f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
                )

            # Check if we are done so that we can save a pad step if all(done)
            done[batch_idx] = bool(
                done[batch_idx] or beam_hyp.is_done(next_scores[batch_idx].max().item(), cur_len)
            )

        # Write the flags back in place, keeping `self._done`'s identity and device.
        self._done.copy_(torch.tensor(done, dtype=torch.bool))
        self.t_dict['body'] += default_timer() - t0

        t0 = default_timer()
        ans = UserDict(
            {
                "next_beam_scores": torch.tensor(next_beam_scores, device=device).view(-1),
                "next_beam_tokens": torch.tensor(next_beam_tokens, device=device).view(-1),
                "next_beam_indices": torch.tensor(next_beam_indices, device=device).view(-1),
            }
        )
        self.t_dict['ans'] += default_timer() - t0
        return ans

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.LongTensor]:
        """Close all open beams and pack the best hypotheses of every sentence.

        Returns:
            ``UserDict`` with int64 ``sequences``, float32 ``sequence_scores`` and
            int64 ``beam_indices`` (``None`` when no beam indices were tracked),
            all on ``input_ids.device``.
        """
        batch_size = len(self._beam_hyps)
        device = input_ids.device
        input_ids = input_ids.cpu().numpy()
        final_beam_scores = final_beam_scores.cpu().numpy()
        done = self._done.tolist()

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if done[batch_idx]:
                continue
            # beam hypothesis class automatically keeps the best beams
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

        # select the best hypotheses
        num_out = batch_size * self.num_beam_hyps_to_keep
        # Explicit int64: NumPy's default `int` is 32-bit on some platforms (e.g.
        # Windows with NumPy < 2), but sequences must come back as LongTensor.
        sent_lengths = np.zeros(num_out, dtype=np.int64)
        best = []
        best_indices = []
        best_scores = np.zeros(num_out, dtype=np.float32)

        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                best_index = best_hyp_tuple[2]
                out_idx = self.num_beam_hyps_to_keep * i + j
                sent_lengths[out_idx] = len(best_hyp)
                best.append(best_hyp)
                best_indices.append(best_index)
                best_scores[out_idx] = best_score

        # prepare for adding eos
        sent_lengths_max = sent_lengths.max().item() + 1
        sent_max_len = int(min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max)
        decoded = np.zeros((num_out, sent_max_len), dtype=np.int64)
        if len(best_indices) > 0 and best_indices[0] is not None:
            indices = np.zeros((num_out, sent_max_len), dtype=np.int64)
        else:
            indices = None

        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill(pad_token_id)
            if indices is not None:
                indices.fill(-1)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
            decoded[i, : sent_lengths[i]] = hypo
            if indices is not None:
                indices[i, : len(best_idx)] = np.asarray(best_idx)
            if sent_lengths[i] < sent_max_len:
                decoded[i, sent_lengths[i]] = eos_token_id

        return UserDict(
            {
                "sequences": torch.tensor(decoded, device=device),
                "sequence_scores": torch.tensor(best_scores, device=device),
                "beam_indices": torch.tensor(indices, device=device) if indices is not None else None,
            }
        )
# Beam scorer implementation selected for this module. To switch, point the
# alias at another implementation instead, e.g. transformers' `BeamSearchScorer`,
# `BeamSearchScorerForComparison` (runs ours and upstream side by side and
# asserts they agree) or `MyBeamSearchScorerViaNumpy` (NumPy port).
ChosenBeamSearchScorer: type = MyBeamSearchScorer
@torch.no_grad()
def my_generate(
    self: GenerationMixin,
    inputs: Optional[torch.Tensor] = None,
    max_length: Optional[int] = None,
    min_length: Optional[int] = None,
    do_sample: Optional[bool] = None,
    early_stopping: Optional[bool] = None,
    num_beams: Optional[int] = None,
    temperature: Optional[float] = None,
    penalty_alpha: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    typical_p: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    bad_words_ids: Optional[Iterable[int]] = None,
    force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
    bos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    no_repeat_ngram_size: Optional[int] = None,
    encoder_no_repeat_ngram_size: Optional[int] = None,
    num_return_sequences: Optional[int] = None,
    max_time: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    decoder_start_token_id: Optional[int] = None,
    use_cache: Optional[bool] = None,
    num_beam_groups: Optional[int] = None,
    diversity_penalty: Optional[float] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    renormalize_logits: Optional[bool] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    constraints: Optional[List[Constraint]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    forced_bos_token_id: Optional[int] = None,
    forced_eos_token_id: Optional[int] = None,
    remove_invalid_values: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    exponential_decay_length_penalty: Optional[Tuple[int, float]] = None,
    suppress_tokens: Optional[List[int]] = None,
    begin_suppress_tokens: Optional[List[int]] = None,
    forced_decoder_ids: Optional[List[List[int]]] = None,
    **model_kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    """Variant of ``GenerationMixin.generate`` that routes to the faster search loops in this module.

    The body follows the structure of ``transformers``' ``generate``: resolve arguments, prepare the
    model inputs and ``max_length``, pick a generation mode, build logits processors and stopping
    criteria, then dispatch. The differences are in the dispatch step:

    * sampling (``num_beams == 1``, ``do_sample=True``) runs ``my_sample``;
    * beam search (``num_beams > 1``, ``do_sample=False``) builds a ``ChosenBeamSearchScorer`` and
      runs ``my_beam_search``;
    * beam sampling builds a ``ChosenBeamSearchScorer`` and then calls ``self.beam_sample``;
    * group (diverse) beam search builds a ``MyBeamSearchScorerViaNumpy`` and then calls
      ``self.group_beam_search``.

    Greedy, contrastive and constrained generation are delegated to the model's own methods.

    The model is taken as an explicit ``self`` argument, so the function has to be called with (or
    bound to) a model instance.

    Args:
        self: the model to generate with.
        inputs: prompt / encoder input, resolved through ``self._prepare_model_inputs``.
        **others: same meaning as in ``GenerationMixin.generate``. Most arguments that are ``None``
            fall back to the attribute of the same name on ``self.config``.

    Returns:
        Whatever the selected search routine returns: a ``torch.LongTensor`` of sequences, or a
        ``ModelOutput`` when ``return_dict_in_generate`` is enabled.

    Raises:
        ValueError: for conflicting ``max_length``/``max_new_tokens``, ``min_length > max_length``,
            invalid ``num_beams``/``num_beam_groups``/``num_return_sequences`` combinations,
            missing ``max_length`` stopping criterion for beam modes, or malformed
            ``force_words_ids``.
    """
    # Debug marker showing that this patched entry point (not the stock `generate`) is in use.
    print('my_generate called')
    # 0. Validate the `.generate()` call
    self._validate_model_class()
    self._validate_model_kwargs(model_kwargs.copy())

    # 1. Set generation parameters if not already defined
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
    num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
    do_sample = do_sample if do_sample is not None else self.config.do_sample
    num_return_sequences = (
        num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
    )
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    if eos_token_id is None and hasattr(self.config, "decoder"):
        eos_token_id = self.config.decoder.eos_token_id

    if pad_token_id is None and eos_token_id is not None:
        if model_kwargs.get("attention_mask", None) is None:
            logger.warning(
                "The attention mask and the pad token id were not set. As a consequence, you may observe "
                "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
            )
        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
        pad_token_id = eos_token_id

    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    # 2. Define model inputs
    # inputs_tensor has to be defined
    # model_input_name is defined if model-specific keyword input is passed
    # otherwise model_input_name is None
    # all model-specific keyword inputs are removed from `model_kwargs`
    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
    batch_size = inputs_tensor.shape[0]

    # 3. Define other model kwargs
    model_kwargs["output_attentions"] = output_attentions
    model_kwargs["output_hidden_states"] = output_hidden_states
    model_kwargs["use_cache"] = use_cache

    accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
    requires_attention_mask = "encoder_outputs" not in model_kwargs

    if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
            inputs_tensor, pad_token_id, eos_token_id
        )

    # decoder-only models should use left-padding for generation
    if not self.config.is_encoder_decoder:
        if pad_token_id is not None and torch.sum(inputs_tensor[:, -1] == pad_token_id) > 0:
            logger.warning(
                "A decoder-only architecture is being used, but right-padding was detected! For correct "
                "generation results, please set `padding_side='left'` when initializing the tokenizer."
            )

    if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
        # if model is encoder decoder encoder_outputs are created
        # and added to `model_kwargs`
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
            inputs_tensor, model_kwargs, model_input_name
        )

    # 4. Prepare `input_ids` which will be used for auto-regressive generation
    if self.config.is_encoder_decoder:
        input_ids = self._prepare_decoder_input_ids_for_generation(
            batch_size,
            decoder_start_token_id=decoder_start_token_id,
            bos_token_id=bos_token_id,
            model_kwargs=model_kwargs,
            device=inputs_tensor.device,
        )
    else:
        # if decoder-only then inputs_tensor has to be `input_ids`
        input_ids = inputs_tensor

    # 5. Prepare `max_length` depending on other stopping criteria.
    input_ids_seq_length = input_ids.shape[-1]
    if max_length is None and max_new_tokens is None:
        warnings.warn(
            "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to "
            f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
            "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
            "using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif max_length is None and max_new_tokens is not None:
        max_length = max_new_tokens + input_ids_seq_length
    elif max_length is not None and max_new_tokens is not None:
        raise ValueError(
            "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
            " limit to the generated output length. Remove one of those arguments. Please refer to the"
            " documentation for more information. "
            "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
        )
    # default to config if still None
    max_length = max_length if max_length is not None else self.config.max_length
    min_length = min_length if min_length is not None else self.config.min_length

    if min_length is not None and min_length > max_length:
        raise ValueError(
            f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum "
            f"length ({max_length})"
        )
    if input_ids_seq_length >= max_length:
        input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {max_length}. This can lead to unexpected behavior. You should consider increasing "
            "`max_new_tokens`."
        )

    # 6. determine generation mode
    is_constraint_gen_mode = constraints is not None or force_words_ids is not None

    is_contrastive_search_gen_mode = (
        top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0
    )

    is_greedy_gen_mode = (
        (num_beams == 1)
        and (num_beam_groups == 1)
        and do_sample is False
        and not is_constraint_gen_mode
        and not is_contrastive_search_gen_mode
    )
    is_sample_gen_mode = (
        (num_beams == 1)
        and (num_beam_groups == 1)
        and do_sample is True
        and not is_constraint_gen_mode
        and not is_contrastive_search_gen_mode
    )
    is_beam_gen_mode = (
        (num_beams > 1)
        and (num_beam_groups == 1)
        and do_sample is False
        and not is_constraint_gen_mode
        and not is_contrastive_search_gen_mode
    )
    is_beam_sample_gen_mode = (
        (num_beams > 1)
        and (num_beam_groups == 1)
        and do_sample is True
        and not is_constraint_gen_mode
        and not is_contrastive_search_gen_mode
    )
    is_group_beam_gen_mode = (
        (num_beams > 1)
        and (num_beam_groups > 1)
        and not is_constraint_gen_mode
        and not is_contrastive_search_gen_mode
    )

    if num_beam_groups > num_beams:
        raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
    if is_group_beam_gen_mode and do_sample is True:
        raise ValueError(
            "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
        )

    if self.device.type != input_ids.device.type:
        warnings.warn(
            "You are calling .generate() with the `input_ids` being on a device type different"
            f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
            f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
            " Please make sure that you have put `input_ids` to the"
            f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
            " running `.generate()`.",
            UserWarning,
        )

    # 7. prepare distribution pre_processing samplers
    logits_processor = self._get_logits_processor(
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=inputs_tensor,
        bad_words_ids=bad_words_ids,
        min_length=min_length,
        max_length=max_length,
        eos_token_id=eos_token_id,
        forced_bos_token_id=forced_bos_token_id,
        forced_eos_token_id=forced_eos_token_id,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        remove_invalid_values=remove_invalid_values,
        exponential_decay_length_penalty=exponential_decay_length_penalty,
        logits_processor=logits_processor,
        renormalize_logits=renormalize_logits,
        suppress_tokens=suppress_tokens,
        begin_suppress_tokens=begin_suppress_tokens,
        forced_decoder_ids=forced_decoder_ids,
    )

    # 8. prepare stopping criteria
    stopping_criteria = self._get_stopping_criteria(
        max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
    )

    # 9. go into different generation modes
    if is_greedy_gen_mode:
        if num_return_sequences > 1:
            raise ValueError(
                f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
            )

        # 10. run greedy search
        return self.greedy_search(
            input_ids,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_contrastive_search_gen_mode:
        if num_return_sequences > 1:
            raise ValueError(
                f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search."
            )

        return self.contrastive_search(
            input_ids,
            top_k=top_k,
            penalty_alpha=penalty_alpha,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_sample_gen_mode:
        # 10. prepare logits warper
        logits_warper = self._get_logits_warper(
            top_k=top_k,
            top_p=top_p,
            typical_p=typical_p,
            temperature=temperature,
            num_beams=num_beams,
            renormalize_logits=renormalize_logits,
        )

        # 11. expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # 12. run sample (module-level timed variant instead of `self.sample`)
        return my_sample(
            self,
            input_ids,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_beam_gen_mode:
        if num_return_sequences > num_beams:
            raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")

        # 10. prepare beam search scorer
        beam_scorer = ChosenBeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=inputs_tensor.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
            num_beam_hyps_to_keep=num_return_sequences,
        )
        # 11. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 12. run beam search (module-level variant instead of `self.beam_search`)
        return my_beam_search(
            self,
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_beam_sample_gen_mode:
        # 10. prepare logits warper
        logits_warper = self._get_logits_warper(
            top_k=top_k,
            top_p=top_p,
            typical_p=typical_p,
            temperature=temperature,
            num_beams=num_beams,
            renormalize_logits=renormalize_logits,
        )

        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")
        # 11. prepare beam search scorer
        beam_scorer = ChosenBeamSearchScorer(
            batch_size=batch_size * num_return_sequences,
            num_beams=num_beams,
            device=inputs_tensor.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
        )

        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=num_beams * num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # 13. run beam sample
        return self.beam_sample(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_group_beam_gen_mode:
        if num_return_sequences > num_beams:
            raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

        if num_beams % num_beam_groups != 0:
            raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")

        if typical_p is not None:
            raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")

        # 10. prepare beam search scorer
        beam_scorer = MyBeamSearchScorerViaNumpy(
            batch_size=batch_size,
            num_beams=num_beams,
            max_length=stopping_criteria.max_length,
            device=inputs_tensor.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
            num_beam_hyps_to_keep=num_return_sequences,
            num_beam_groups=num_beam_groups,
        )
        # 11. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 12. run beam search
        return self.group_beam_search(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_constraint_gen_mode:
        if num_return_sequences > num_beams:
            raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")

        if num_beams <= 1:
            raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

        if do_sample:
            raise ValueError("`do_sample` needs to be false for constrained generation.")

        if num_beam_groups is not None and num_beam_groups > 1:
            raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

        # Copy into a fresh list: the forced-word constraints appended below must not leak into
        # (mutate) the `constraints` list owned by the caller.
        final_constraints = list(constraints) if constraints is not None else []

        if force_words_ids is not None:

            def typeerror():
                raise ValueError(
                    "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
                    f"of positive integers, but is {force_words_ids}."
                )

            if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
                typeerror()

            for word_ids in force_words_ids:
                # Check the entry's own shape before peeking at `word_ids[0]`, so malformed input
                # raises the ValueError above rather than an IndexError/TypeError.
                if not isinstance(word_ids, list) or len(word_ids) == 0:
                    typeerror()

                if isinstance(word_ids[0], list):
                    # Disjunctive constraint: a list of alternative token-id sequences.
                    if any(not isinstance(token_ids, list) for token_ids in word_ids):
                        typeerror()
                    if any(
                        any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                        for token_ids in word_ids
                    ):
                        typeerror()

                    constraint = DisjunctiveConstraint(word_ids)
                else:
                    # Phrasal constraint: a single token-id sequence.
                    if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                        typeerror()

                    constraint = PhrasalConstraint(word_ids)
                final_constraints.append(constraint)

        # 10. prepare beam search scorer
        constrained_beam_scorer = ConstrainedBeamSearchScorer(
            constraints=final_constraints,
            batch_size=batch_size,
            num_beams=num_beams,
            device=inputs_tensor.device,
            length_penalty=length_penalty,
            do_early_stopping=early_stopping,
            num_beam_hyps_to_keep=num_return_sequences,
        )
        # 11. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 12. run beam search
        return self.constrained_beam_search(
            input_ids,
            constrained_beam_scorer=constrained_beam_scorer,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )
def my_sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    **model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
    """Multinomial-sampling decode loop with per-phase wall-clock timing.

    Each step does four things:

    * prepare the model inputs;
    * run a forward pass;
    * apply ``logits_processor`` then ``logits_warper``, and sample the next token with
      ``torch.multinomial`` over the softmax of the scores;
    * append the token.

    Sequences that have emitted ``eos_token_id`` keep receiving ``pad_token_id``. The loop ends
    when every sequence has finished or ``stopping_criteria`` fires. Under ``synced_gpus`` a
    finished peer keeps running forward passes until all peers report completion via
    ``dist.all_reduce``.

    After the loop, the total time and a per-phase breakdown are printed to stdout. The phases
    are ``prepare_inputs``, ``model``, ``sample`` and ``misc``.

    Args:
        self: the model to sample from.
        input_ids: the prompt / decoder prefix; new tokens are concatenated on the last dim.
        logits_processor, logits_warper: score transforms, applied in that order each step.
        stopping_criteria: extra stop conditions; ``max_length`` (deprecated) is folded in here.
        pad_token_id, eos_token_id, output_*: fall back to ``self.config`` when ``None``.
        synced_gpus: keep stepping until all distributed peers have finished.

    Returns:
        ``input_ids`` extended with the sampled tokens, or a ``Sample*Output`` carrying them
        (plus optional scores/attentions/hidden states) when ``return_dict_in_generate`` is set.

    Raises:
        ValueError: if ``eos_token_id`` is set but ``pad_token_id`` is not.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished (1 = still generating, 0 = done)
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

    # Timing diagnostics: accumulated seconds per phase of the decode loop.
    t_start = default_timer()
    t_dict = defaultdict(float)

    this_peer_finished = False  # used by synced_gpus only
    # auto-regressive generation
    while True:
        t0 = default_timer()

        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        t_dict['prepare_inputs'] += default_timer() - t0
        t0 = default_timer()

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        t_dict['model'] += default_timer() - t0
        t0 = default_timer()

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        t_dict['sample'] += default_timer() - t0
        t0 = default_timer()

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished, or if we exceed the maximum length
        should_stop = unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores)

        # Record the bookkeeping time *before* a possible `break`, so the final step is counted too.
        t_dict['misc'] += default_timer() - t0

        if should_stop:
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    print(
        f'sample '
        f't_total={default_timer() - t_start:.3f}s '
        't_dict=' + str({k: f'{v:.3}s' for k, v in t_dict.items()})
    )

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return SampleEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return SampleDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return input_ids
def my_beam_search(
    self: GenerationMixin,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    **model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
    """Run beam-search decoding with coarse wall-clock profiling.

    The control flow mirrors the beam-search loop of transformers'
    `GenerationMixin.beam_search` (NOTE(review): presumably copied from the
    transformers version this code targets -- confirm before upgrading), with
    `default_timer()` probes added and a timing summary printed at the end.

    Args:
        self: the generation model. It is called as `self(**model_inputs, ...)`
            and must provide `prepare_inputs_for_generation`,
            `adjust_logits_during_generation`,
            `_update_model_kwargs_for_generation`, `_reorder_cache` and `config`.
        input_ids: prompt ids whose first dimension must equal
            `num_beams * batch_size` (validated below).
        beam_scorer: scorer exposing `_beam_hyps`, `num_beams`, `process`,
            `finalize` and `is_done`. `batch_size` is `len(beam_scorer._beam_hyps)`.
        logits_processor: applied to the per-step log-probabilities; defaults
            to an empty `LogitsProcessorList`.
        stopping_criteria: checked after every step; defaults to an empty
            `StoppingCriteriaList`, in which case a "loop forever" warning is issued.
        max_length: deprecated; folded into `stopping_criteria` with a warning.
        pad_token_id, eos_token_id: fall back to `self.config` when None.
        output_attentions, output_hidden_states, output_scores,
        return_dict_in_generate: fall back to `self.config` when None.
        synced_gpus: keep running forward passes until every peer has finished,
            using `dist.all_reduce` on a per-peer "still running" flag.
        **model_kwargs: passed to `prepare_inputs_for_generation` and updated
            after every step via `_update_model_kwargs_for_generation`.

    Returns:
        `sequence_outputs["sequences"]` from `beam_scorer.finalize`, or a
        `BeamSearchEncoderDecoderOutput` / `BeamSearchDecoderOnlyOutput` when
        `return_dict_in_generate` is true.

    Raises:
        ValueError: if `input_ids.shape[0] != num_beams * batch_size`.

    Side effects:
        Prints a one-line timing summary (`t_total` plus the `t_dict` buckets)
        to stdout on every call.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    if len(stopping_criteria) == 0:
        warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape

    # Every batch entry must have exactly `num_beams` rows in `input_ids`.
    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    # One tuple of past beam indices per row of `input_ids`; only tracked when scores are returned.
    beam_indices = (
        tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
    )
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
    # of the first beam are considered to avoid sampling the exact same tokens across all beams.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    # Profiling state: `t_start` for the total, `t_dict` accumulates seconds per phase
    # ('model_2', 'scorer', 'updates'). The '_' key is NOT a duration -- see the sync below.
    t_start = default_timer()
    t_dict = defaultdict(lambda: 0.0)

    this_peer_finished = False  # used by synced_gpus only
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        t0 = default_timer()

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # The finer-grained probes below are disabled. While they are, 'model_2' spans
        # input preparation + forward pass + the host sync on the logits.
        # t_dict['prepare_inputs'] += default_timer() - t0
        # t0 = default_timer()

        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # t_dict['model'] += default_timer() - t0
        # t0 = default_timer()

        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        # Copying a reduction of the logits to the host (`.cpu()`) blocks until the forward
        # pass has actually finished -- presumably so 'model_2' measures real compute rather
        # than asynchronous kernel launches. The summed value is meaningless and is printed
        # under key '_' with an 's' suffix.
        # NOTE(review): this costs a host sync every step; drop it outside of profiling.
        t_dict['_'] += next_token_logits.sum().cpu().numpy()
        t_dict['model_2'] += default_timer() - t0
        t0 = default_timer()

        # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
        # cannot be generated both before and after the `nn.functional.log_softmax` operation.
        next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)

        # t_dict['_'] += next_token_scores.sum().cpu().numpy()
        # t_dict['calc_next_token'] += default_timer() - t0
        # t0 = default_timer()

        # Apply logits processors, then add each beam's running score so that the
        # candidates below are ranked by cumulative score.
        next_token_scores_processed = logits_processor(input_ids, next_token_scores)
        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

        # t_dict['_'] += next_token_scores.sum().cpu().numpy()
        # t_dict['calc_scores'] += default_timer() - t0
        # t0 = default_timer()

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores_processed,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # t_dict['related_return_dict_in_generate'] += default_timer() - t0
        # t0 = default_timer()

        # reshape for beam search: flatten all beams of one batch entry into a single row
        # so that one top-k per row selects across beams
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # t_dict['_'] += next_token_scores.sum().cpu().numpy()
        # t_dict['reshape'] += default_timer() - t0
        # t0 = default_timer()

        # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
        )

        # t_dict['_'] += next_token_scores.sum().cpu().numpy()
        # t_dict['topk'] += default_timer() - t0
        # t0 = default_timer()

        # Split each flat index back into (source beam, token id).
        next_indices = torch_int_div(next_tokens, vocab_size)
        next_tokens = next_tokens % vocab_size

        # t_dict['sample'] += default_timer() - t0
        # t0 = default_timer()

        # stateless
        beam_outputs = beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            beam_indices=beam_indices,
        )

        # 'scorer' covers everything since the logits sync: logit adjustment, log_softmax,
        # logits processors, top-k and `beam_scorer.process`.
        t_dict['scorer'] += default_timer() - t0
        t0 = default_timer()

        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        # Re-gather the surviving rows and append the chosen token to each.
        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # Reorder the model's `past` cache so it follows the selected beams (`beam_idx`).
        if model_kwargs["past"] is not None:
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

        if return_dict_in_generate and output_scores:
            beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

        # increase cur_len
        cur_len = cur_len + 1

        t_dict['updates'] += default_timer() - t0

        if beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                this_peer_finished = True

    # `next_tokens` / `next_indices` are those of the last executed step.
    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=beam_indices,
    )

    # Unconditional profiling output to stdout.
    print(
        f'beam_search '
        f't_total={default_timer() - t_start:.3f}s '
        't_dict=' + str({k: f'{v:.3}s' for k, v in t_dict.items()})
        # ' beam_scorer=' + str({k: f'{v:.3}s' for k, v in beam_scorer.t_dict.items()})
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None

        if self.config.is_encoder_decoder:
            return BeamSearchEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=sequence_outputs["beam_indices"],
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return BeamSearchDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return sequence_outputs["sequences"]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE: copied from:
# * https://github.com/huggingface/transformers/blob/main/tests/generation/test_beam_search.py
# * and partly from https://github.com/huggingface/transformers/blob/main/tests/test_modeling_common.py#L2609
import random | |
import unittest | |
import torch | |
from transformers.generation import BeamHypotheses | |
from transformers.testing_utils import require_torch, torch_device | |
from ....utils.huggingface.beam_search import MyBeamSearchScorer | |
# Module-wide random source used by `ids_tensor` / `floats_tensor` when no `rng` is passed.
# It is created without an explicit seed, so the generated test inputs differ between runs.
global_rng: random.Random = random.Random()
def ids_tensor(shape, vocab_size, rng=None, name=None):
    """Return a random int64 (``torch.long``) id tensor of the given shape.

    Each element is drawn independently with ``rng.randint(0, vocab_size - 1)``,
    i.e. uniformly from ``[0, vocab_size - 1]`` inclusive, in row-major order.

    Args:
        shape: sequence of dimension sizes.
        vocab_size: exclusive upper bound of the generated ids.
        rng: a ``random.Random``-like source; defaults to the module-level
            ``global_rng``.
        name: accepted but unused.

    Returns:
        A contiguous ``torch.long`` tensor on ``torch_device``.
    """
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    # One draw per element, in the same order as the original append loop.
    values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]
    return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Return a float32 tensor of ``shape`` filled with ``rng.random() * scale`` samples.

    For a positive ``scale`` the values lie in ``[0, scale)``. Samples are drawn
    one per element in row-major order from ``rng``, or from the module-level
    ``global_rng`` when ``rng`` is None. ``name`` is accepted but unused.
    The result is contiguous and placed on ``torch_device``.
    """
    source = rng if rng is not None else global_rng

    count = 1
    for extent in shape:
        count *= extent

    samples = [source.random() * scale for _ in range(count)]
    flat = torch.tensor(data=samples, dtype=torch.float, device=torch_device)
    return flat.view(shape).contiguous()
class BeamSearchTester:
    """Fixture builder and assertion helpers for the `MyBeamSearchScorer` tests.

    Each `check_*` method consumes the tuple returned by `prepare_inputs` and
    reports failures through `parent`, which is expected to be a
    `unittest.TestCase` (its `assert*` methods are called).
    """

    def __init__(
        self,
        parent,
        batch_size=3,
        sequence_length=10,
        vocab_size=99,
        pad_token_id=0,
        max_length=20,
        num_beams=4,
        length_penalty=2.0,
        do_early_stopping=True,
        num_beam_hyps_to_keep=2,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep

        # `ids_tensor` only draws ids in [0, vocab_size - 1], so this id can never be
        # generated randomly and unambiguously marks end-of-sequence.
        self.eos_token_id = vocab_size + 1

    def prepare_beam_scorer(self, **kwargs):
        """Build a `MyBeamSearchScorer`; any constructor field can be overridden via `kwargs`."""
        return MyBeamSearchScorer(
            batch_size=kwargs.get("batch_size", self.batch_size),
            num_beams=kwargs.get("num_beams", self.num_beams),
            device=torch_device,
            length_penalty=kwargs.get("length_penalty", self.length_penalty),
            do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
            num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
        )

    def prepare_inputs(self):
        """Return random `(input_ids, next_tokens, next_indices, next_scores)` for one scorer step.

        `input_ids` has shape `(batch_size * num_beams, sequence_length)`. The other
        three have shape `(batch_size, 2 * num_beams)`. `next_scores` is non-positive
        and sorted in descending order along each row.
        """
        input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
        next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
        next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
        next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
        return (input_ids, next_tokens, next_indices, next_scores)

    def check_beam_hypotheses(self, input_ids, *args):
        """Check `BeamHypotheses` bookkeeping: count, type, `num_beams`, `worst_score` and `is_done`."""
        # check that correct number of beam hypotheses is set in beam scorer
        beam_scorer = self.prepare_beam_scorer(do_early_stopping=True)
        beam_hyp = beam_scorer._beam_hyps[0]

        self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size)

        # check correct type
        self.parent.assertIsInstance(beam_hyp, BeamHypotheses)

        # check that num_beams is correctly set
        self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)

        # fill all `num_beams` slots (early stopping is activated for this scorer)
        for beam_idx in range(self.num_beams):
            beam_hyp.add(input_ids[beam_idx], -10.0)

        # early stopping is True -> once all slots are full the score does not matter
        self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))

        # re-init, this time with early stopping deactivated
        beam_scorer = self.prepare_beam_scorer(do_early_stopping=False)
        beam_hyp = beam_scorer._beam_hyps[0]

        # add `num_beams + 1` beams to change `worst_score`
        for beam_idx in range(self.num_beams + 1):
            beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))

        # -10.0 is removed => -9.0 is worst score (length-normalized)
        self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))

        # -5.0 is better than worst score => should not be finished
        self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))

        # -20.0 is worse than worst score => should be finished
        self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))

    def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_scores):
        """Check `process`: it rejects all-eos rows, detects completion, and selects/saves beams correctly."""
        # check too many eos tokens
        beam_scorer = self.prepare_beam_scorer()

        tokens = next_tokens.clone()
        tokens[0, :] = self.eos_token_id

        with self.parent.assertRaises(ValueError):
            beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)

        # check all batches are done
        beam_scorer = self.prepare_beam_scorer()

        tokens = next_tokens.clone()
        tokens[:, : self.num_beams] = self.eos_token_id
        # every row's beam history is simply [0, 1, ..., cur_len - 1]
        beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
        beam_indices = tuple(tuple(b) for b in beam_indices)
        beam_scorer.process(
            input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
        )
        # beam scorer should be done
        self.parent.assertTrue(beam_scorer.is_done)

        # check outputs when exactly the second candidate of every batch entry is eos
        beam_scorer = self.prepare_beam_scorer()

        tokens = next_tokens.clone()
        tokens[:, 1] = self.eos_token_id
        beam_outputs = beam_scorer.process(
            input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
        )
        output_scores = beam_outputs["next_beam_scores"]
        output_tokens = beam_outputs["next_beam_tokens"]
        output_indices = beam_outputs["next_beam_indices"]

        def cut_expected_tensor(tensor):
            # drop column 1 (the eos candidate) and keep the best `num_beams` of the rest
            return torch.cat([tensor[:, :1], tensor[:, 2: self.num_beams + 1]], dim=1).flatten()

        # check all outputs
        # cut out id of eos token and take best `num_beams` outputs
        expected_output_tokens = cut_expected_tensor(tokens)
        expected_output_scores = cut_expected_tensor(next_scores)

        # add num_beams * batch_idx
        expected_output_indices = (
            cut_expected_tensor(next_indices)
            + (torch.arange(self.num_beams * self.batch_size,
                            device=torch_device) // self.num_beams) * self.num_beams
        )

        self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
        self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
        self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))

        # make sure ids of eos token are correctly saved in beam_hyps of beam scorer;
        # the expected history matches the `beam_indices` fixture built above
        expected_beam_indices = list(range(input_ids.shape[-1]))
        for batch_idx in range(self.batch_size):
            correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
            self.parent.assertListEqual(
                input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
            )
            self.parent.assertListEqual(
                expected_beam_indices + [correct_idx],
                torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
            )

    def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
        """Check `finalize`: output shapes, non-positive scores, and eos placement."""
        # max_length should be only one more than current input_ids to check that eos is correctly appended
        max_length = self.sequence_length + 1
        beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)

        # update beams and append to input_ids
        tokens = next_tokens.clone()
        # first batch, first output has to finish with eos token id since scores are correctly sorted
        tokens[0, 0] = self.eos_token_id
        # make sure corresponding score is as good as possible to surely be picked first;
        # work on a copy so the caller's tensor is not mutated
        next_scores = next_scores.clone()
        next_scores[0, 0] = 0.0
        beam_outputs = beam_scorer.process(
            input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
        )
        output_scores = beam_outputs["next_beam_scores"]
        output_tokens = beam_outputs["next_beam_tokens"]
        output_indices = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)

        # finalize
        beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
        beam_indices = tuple(tuple(b) for b in beam_indices)
        sequence_output = beam_scorer.finalize(
            input_ids,
            output_scores,
            output_tokens,
            output_indices,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            max_length=max_length,
            beam_indices=beam_indices,
        )

        sequences = sequence_output["sequences"]
        sequence_scores = sequence_output["sequence_scores"]

        # since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
        self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
        self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])

        # check sequence_scores
        self.parent.assertFalse((sequence_scores > 0).any().item())

        # first batch has to finish with eos_token
        self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)

        # other batches cannot finish with eos token
        for batch_idx in range(1, self.batch_size):
            self.parent.assertNotEqual(sequences[batch_idx, -1].item(), self.eos_token_id)

        # now test that if `num_beam_hyps_to_keep` equals `num_beams` => all beams are returned
        beam_scorer.num_beam_hyps_to_keep = self.num_beams
        sequence_output = beam_scorer.finalize(
            input_ids,
            output_scores,
            output_tokens,
            output_indices,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            max_length=max_length,
            beam_indices=beam_indices,
        )
        sequences = sequence_output["sequences"]
        sequence_scores = sequence_output["sequence_scores"]

        self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
        self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
@require_torch
class BeamSearchTest(unittest.TestCase):
    """Runs each `BeamSearchTester` check on freshly generated random inputs."""

    def setUp(self):
        # A new tester per test method; it reports failures through `self`.
        self.beam_search_tester = BeamSearchTester(self)

    def test_beam_hypotheses(self):
        tester = self.beam_search_tester
        tester.check_beam_hypotheses(*tester.prepare_inputs())

    def test_beam_scorer_update(self):
        tester = self.beam_search_tester
        tester.check_beam_scorer_update(*tester.prepare_inputs())

    def test_beam_scorer_finalize(self):
        tester = self.beam_search_tester
        tester.check_beam_scores_finalize(*tester.prepare_inputs())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment