Whisper perplexity
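
The snippets below assume a Whisper model, processor, and device already in scope. A minimal setup sketch (the openai/whisper-small checkpoint is an assumption; any Whisper checkpoint works the same way):

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Checkpoint choice is illustrative; substitute your own fine-tuned model
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
model.eval()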
import torch
import torchaudio

def eval(audio, text):
    # Whisper expects 16 kHz input; a 44.1 kHz source rate is hardcoded here
    audio = torchaudio.transforms.Resample(orig_freq=44100, new_freq=16000)(torch.tensor(audio).unsqueeze(0)).squeeze()
    # Tokenize the reference transcript, including Whisper's special prefix tokens
    tokenized_seq = torch.tensor([processor.tokenizer(text, add_special_tokens=True).input_ids]).to(device)
    # Teacher forcing: targets are the sequence shifted left, decoder inputs shifted right
    decoder_input_ids = tokenized_seq[:, 1:]
    decoder_input_ids_right_shifted = tokenized_seq[:, :-1]
    # Extract log-mel input features from the resampled audio
    processed_in = processor(audio, sampling_rate=16000, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model(input_features=processed_in.input_features, decoder_input_ids=decoder_input_ids_right_shifted)
    # Convert logits to log-probabilities
    log_prob_all = torch.nn.functional.log_softmax(output.logits, dim=-1)
    # Gather the log-probability assigned to each ground-truth token
    log_prob = log_prob_all.take_along_dim(decoder_input_ids[..., None], dim=-1)
    # Perplexity = exp of the mean negative log-likelihood
    perplexity = torch.exp(-log_prob.mean())
    return perplexity.item()

# audio: 1-D waveform array, text: reference transcript
eval(audio, text)
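
A usage sketch; the wav path and transcript below are placeholders, and the file is assumed to match the 44.1 kHz source rate hardcoded in eval:

# Placeholder inputs: swap in your own audio file and reference text
waveform, sr = torchaudio.load("sample.wav")    # assumed 44.1 kHz source
audio = waveform.mean(dim=0).numpy()            # downmix to mono
print(eval(audio, "the reference transcript"))  # lower perplexity = better fit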
@indiejoseph (Author)

Asked ChatGPT-4o to generate a version that supports batched audio input:

import torch
import torchaudio

def eval_batch(audio_batch, text_batch, processor, model, device):
    """
    Evaluate a batch of audio samples and corresponding text.
    
    Parameters:
    - audio_batch: List of audio arrays
    - text_batch: List of corresponding text transcripts
    - processor: The processor used for tokenization and feature extraction
    - model: The model used for evaluation
    - device: Device to run inference on ("cpu" or "cuda")
    
    Returns:
    - List of perplexity scores for each (audio, text) pair
    """
    assert len(audio_batch) == len(text_batch), "Audio and text batch must have the same length"
    
    # Resample to Whisper's expected 16 kHz (a 44.1 kHz source rate is assumed);
    # convert to numpy so the feature extractor treats the list as a batch
    resampler = torchaudio.transforms.Resample(orig_freq=44100, new_freq=16000)
    processed_audio = [resampler(torch.tensor(audio).unsqueeze(0)).squeeze().numpy() for audio in audio_batch]
    
    # Process text
    tokenized_seqs = [processor.tokenizer(text, add_special_tokens=True).input_ids for text in text_batch]
    tokenized_tensors = [torch.tensor(seq) for seq in tokenized_seqs]
    
    # Note: Whisper's pad token shares its id with <|endoftext|>, so it is used
    # only as a filler value here; loss masking is done by length below
    pad_token_id = processor.tokenizer.pad_token_id
    
    decoder_input_ids = torch.nn.utils.rnn.pad_sequence(
        [seq[1:] for seq in tokenized_tensors], batch_first=True, padding_value=pad_token_id
    ).to(device)
    
    decoder_input_ids_right_shifted = torch.nn.utils.rnn.pad_sequence(
        [seq[:-1] for seq in tokenized_tensors], batch_first=True, padding_value=pad_token_id
    ).to(device)
    
    # Process audio: the Whisper feature extractor pads/truncates every clip to
    # the fixed 30 s window by default, so no padding argument is needed
    processed_inputs = processor(processed_audio, sampling_rate=16000, return_tensors="pt").to(device)
    
    with torch.no_grad():
        output = model(
            input_features=processed_inputs.input_features, 
            decoder_input_ids=decoder_input_ids_right_shifted
        )
    
    # Convert logits to log-probabilities
    log_prob_all = torch.nn.functional.log_softmax(output.logits, dim=-1)
    
    # Take log-probabilities of the ground-truth tokens
    log_prob = log_prob_all.gather(dim=-1, index=decoder_input_ids.unsqueeze(-1)).squeeze(-1)
    
    # Mask out padding by true target length rather than by token id: comparing
    # against pad_token_id would also drop each sequence's real <|endoftext|>
    # token, since Whisper uses the same id for both
    target_lengths = torch.tensor([len(seq) - 1 for seq in tokenized_tensors], device=device)
    mask = torch.arange(decoder_input_ids.shape[1], device=device)[None, :] < target_lengths[:, None]
    masked_log_prob = log_prob * mask
    
    # Per-sequence perplexity: exp of the mean NLL over non-padded positions
    perplexities = torch.exp(-(masked_log_prob.sum(dim=-1) / mask.sum(dim=-1)))
    
    return perplexities.tolist()
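
A usage sketch reusing the processor, model, and device from the setup above; the waveforms and transcripts are placeholders:

# Placeholder batch: 1-D 44.1 kHz waveforms with matching transcripts
audios = [audio_a, audio_b]
texts = ["first reference transcript", "second reference transcript"]
scores = eval_batch(audios, texts, processor, model, device)
print(scores)  # one perplexity per (audio, text) pair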
