# Fine-tune a copy of Whisper so it still transcribes well when the encoder is
# run with a shortened audio context: the trainable copy's decoder hidden states
# are regressed onto those of a frozen, full-context copy of the same model.
MODEL = "medium.en"

from transformers import WhisperModel

# Trainable student copy and a frozen reference copy of the same checkpoint.
model_train = WhisperModel.from_pretrained(f"openai/whisper-{MODEL}").cuda().train()
model_base = WhisperModel.from_pretrained(f"openai/whisper-{MODEL}").cuda().eval()

from datasets import load_dataset
from transformers import WhisperProcessor

# FLEURS (English) provides audio arrays plus raw transcriptions for training.
ds = load_dataset("google/fleurs", "en_us", split="train", trust_remote_code=True)
processor = WhisperProcessor.from_pretrained(f"openai/whisper-{MODEL}")

def get_sample(example):
    waveform = example["audio"]["array"]
    sampling_rate = example["audio"]["sampling_rate"]

    # Log-mel features; the processor pads/truncates to the 30 s window (3000 frames).
    input_features = processor(
        waveform, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features

    return {
        "length": len(waveform) / sampling_rate,
        "input_features": input_features,
        "input_ids": processor.tokenizer.encode(example["raw_transcription"].lower()),
    }


# Multilingual checkpoints need forced language/task prompt ids; the *.en models do not.
if ".en" not in MODEL:
    print(processor.get_decoder_prompt_ids(language="english", task="transcribe"))

# Sanity check: the encoded transcript should decode back to the lowercased text.
[processor.tokenizer.decode(i) for i in get_sample(ds[1])["input_ids"]]

import torch
from tqdm import tqdm
from torch import nn
from torch.utils.tensorboard import SummaryWriter


def compute_partially_encoder(model, data, n_audio_ctx):
    # The stride-2 conv in the encoder front-end halves the time axis, so
    # n_audio_ctx output positions correspond to 2 * n_audio_ctx mel frames:
    # pad or trim the features accordingly.
    diffy = 2 * n_audio_ctx - data.shape[2]

    if diffy > 0:
        data = nn.functional.pad(data, [0, diffy, 0, 0, 0, 0], "constant", 0.0)
    elif diffy < 0:
        data = data[:, :, :diffy]

    # Full context: the stock encoder already does the right thing.
    if n_audio_ctx == 1500:
        return model.encoder(data).last_hidden_state

    # Otherwise replicate WhisperEncoder.forward with truncated position embeddings.
    input_embeds = nn.functional.gelu(model.encoder.conv1(data))
    input_embeds = nn.functional.gelu(model.encoder.conv2(input_embeds))
    input_embeds = input_embeds.permute(0, 2, 1)

    embed_pos = model.encoder.embed_positions.weight[:n_audio_ctx]

    hidden_states = input_embeds + embed_pos
    hidden_states = nn.functional.dropout(hidden_states, p=model.encoder.dropout, training=model.encoder.training)

    for encoder_layer in model.encoder.layers:
        # LayerDrop: randomly skip layers during training (a no-op when layerdrop == 0).
        to_drop = False
        if model.encoder.training:
            dropout_probability = torch.rand([])
            if dropout_probability < model.encoder.layerdrop:
                to_drop = True

        if not to_drop:
            if model.encoder.gradient_checkpointing and model.encoder.training:
                layer_outputs = model.encoder._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    None,
                    None,
                    False,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    None,
                    layer_head_mask=None,
                    output_attentions=False,
                )
            hidden_states = layer_outputs[0]

    hidden_states = model.encoder.layer_norm(hidden_states)
    return hidden_states
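
# Illustrative check (not in the original notebook): since the stride-2 conv
# halves the time axis, asking for 750 positions on a 30 s feature matrix yields
# a (1, 750, d_model) tensor from the frozen encoder.
with torch.no_grad():
    feats = get_sample(ds[0])["input_features"].cuda()
    partial = compute_partially_encoder(model_base, feats, 750)
print(partial.shape)  # expected: torch.Size([1, 750, model_base.config.d_model])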


def compute_hidden_state_loss(model_train, model_base, optimizer, criterion, example):
    optimizer.zero_grad()

    # 1500 encoder positions cover 30 s of audio, i.e. 50 positions per second.
    n_ctx = int(round((1500.0 / 30.0) * example["length"]))

    # Jitter the context length so training does not lock onto one exact size.
    extra_ctx = torch.randint(-min(64, n_ctx // 3), min(64, n_ctx // 3), (1,)).item()
    n_ctx += extra_ctx

    input_features = example["input_features"].cuda()
    input_ids = torch.tensor([example["input_ids"]], dtype=torch.long).cuda()

    # Student pass: shortened encoder context, gradients enabled.
    encoder_hidden_states_partial = compute_partially_encoder(model_train, input_features, n_ctx)
    output_partial = model_train.decoder(
        input_ids=input_ids,
        encoder_hidden_states=encoder_hidden_states_partial,
        output_hidden_states=True,
    )

    # Teacher pass: frozen base model with the full 1500-position context.
    with torch.no_grad():
        encoder_hidden_states_full = compute_partially_encoder(model_base, input_features, 1500)
        output_full = model_base.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states_full,
            output_hidden_states=True,
        )

    # Match every decoder hidden state, not just the last one.
    loss = criterion(
        # output_partial.hidden_states[-1],
        # output_full.hidden_states[-1]
        torch.cat(output_partial.hidden_states, 0),
        torch.cat(output_full.hidden_states, 0),
    )

    loss.backward()
    optimizer.step()

    return loss
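
# Worked example (illustrative): a 9.6 s clip targets round(1500 / 30 * 9.6) = 480
# encoder positions; the jitter is drawn from [-min(64, 480 // 3), min(64, 480 // 3)),
# i.e. [-64, 63], so the student sees between 416 and 543 positions for that clip.
print(int(round((1500.0 / 30.0) * 9.6)))  # 480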

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model_train.parameters(), lr=1e-6)

writer = SummaryWriter()
writer.add_text("name", f"{MODEL} v3")

num_length = 0
step = 0
for epoch in range(8):
    pbar = tqdm(ds.shuffle(seed=epoch))
    for example in pbar:
        example = get_sample(example)
        # Skip clips that would overflow the 30 s feature window.
        if example["length"] > 29.0:
            continue

        loss = compute_hidden_state_loss(model_train, model_base, optimizer, criterion, example)
        step += 1
        num_length += example["length"]

        writer.add_scalar("loss/train", loss.item(), step)
        writer.add_scalar("length/train", num_length, step)
        writer.add_scalar("epoch/train", epoch, step)

        pbar.set_description(f"Epoch {epoch}, Loss: {loss.item()}")

from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Small LibriSpeech validation set for a quick qualitative comparison:
ds_eval = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# Full generation model in Hugging Face format; its inner model gets swapped below:
model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{MODEL}").eval().cuda()

for i in range(64):
    audio_sample = ds_eval[i]["audio"]
    waveform = audio_sample["array"]
    sampling_rate = audio_sample["sampling_rate"]

    # Use the processor to extract features, then transcribe with both sets of weights:
    input_features = processor(
        waveform, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features.cuda()

    model.model = model_base.eval().cuda()
    predicted_ids_base = model.generate(input_features)
    model.model = model_train.eval().cuda()
    predicted_ids_train = model.generate(input_features)

    # Decode token ids to text
    transcription = processor.batch_decode(
        [predicted_ids_base[0], predicted_ids_train[0]], skip_special_tokens=True
    )

    print(f"\n\nGrndTr: {ds_eval[i]['text'].lower()}\nModelB: {transcription[0]}\nModelT: {transcription[1]}")

# Move the trained weights back into a full generation model and save it.
model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{MODEL}").eval().cpu()
model.model = model_train.eval().cpu()

model.save_pretrained(f"model_train-{MODEL}3")

import shutil
shutil.make_archive(f"model_train-{MODEL}3", 'zip', f"model_train-{MODEL}3")

torch.save(model, f"model_train-{MODEL}3.pt")
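
# Reload sketch (illustrative, not part of the original run): save_pretrained
# writes a regular Hugging Face checkpoint directory, so the distilled model can
# be restored later without the pickled .pt file.
reloaded = WhisperForConditionalGeneration.from_pretrained(f"model_train-{MODEL}3").eval()
print(sum(p.numel() for p in reloaded.parameters()))  # same parameter count as the base checkpoint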