Last active
August 24, 2024 15:09
-
-
Save Jourdelune/50209425671f37b72f0794865ddd64d0 to your computer and use it in GitHub Desktop.
Code to fine-tune a lyrics detector
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import evaluate | |
import numpy as np | |
from audiotools import AudioSignal | |
from datasets import Audio, load_dataset | |
from transformers import ( | |
AutoFeatureExtractor, | |
AutoModelForAudioClassification, | |
Trainer, | |
TrainingArguments, | |
) | |
from pipelines.transcript import TranscriptModel | |
# Load the training split, shuffle it with a fixed seed and have the audio
# column decoded at 16 kHz.
ds = (
    load_dataset("lewtun/music_genres", split="train")
    .shuffle(seed=42)
    .cast_column("audio", Audio(sampling_rate=16_000))
)

# Shared module-level helpers: the evaluation metric, the wav2vec2 feature
# extractor and the transcription model used to label the clips.
accuracy = evaluate.load("accuracy")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
asr = TranscriptModel()
def prepare_audio(row, idx):
    """Add a binary ``lyrics`` field to a dataset row.

    The clip at position ``idx`` of the module-level ``ds`` is wrapped in an
    ``AudioSignal``, converted to mono, resampled to 16 kHz and handed to
    ``asr.contain_lyrics``. The result is stored as an int.

    Args:
        row: The example being mapped; it is modified and returned.
        idx: Position of ``row`` in ``ds`` (``map(..., with_indices=True)``).

    Returns:
        ``row`` with ``row["lyrics"]`` set. When the audio cannot be loaded
        the error is printed and the value stays ``0``.
    """
    row["lyrics"] = 0

    try:
        # Index the dataset only once: each ``ds[idx]`` lookup fetches (and
        # decodes) the whole row again.
        sample = ds[idx]["audio"]
        audio = AudioSignal(sample["array"], sample["sampling_rate"])
    except Exception as e:
        # Best effort: a clip that fails to load keeps the default label.
        print(f"Error processing audio: {e}")
        return row

    # convert to mono and resample
    audio = audio.to_mono().resample(16000)

    # NOTE(review): ``[0, 0]`` assumes ``numpy()`` is laid out as
    # (batch, channel, samples) - confirm against audiotools.
    transcription = asr.contain_lyrics(audio.numpy()[0, 0])
    row["lyrics"] = int(transcription)
    return row
def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``accuracy`` metric.

    Args:
        eval_pred: Object exposing ``predictions`` (per-class scores, one row
            per example) and ``label_ids`` (the reference labels).

    Returns:
        The value returned by ``accuracy.compute`` for the arg-max class of
        every row.
    """
    predicted_ids = np.argmax(eval_pred.predictions, axis=1)
    reference_ids = eval_pred.label_ids
    return accuracy.compute(predictions=predicted_ids, references=reference_ids)
def preprocess_function(examples):
    """Run the module-level ``feature_extractor`` over a batch of clips.

    Args:
        examples: Batch dict whose ``"audio"`` entry is a list of dicts, each
            carrying the waveform under ``"array"``.

    Returns:
        The feature extractor's output for the batch, truncated to at most
        ``16000 * 30`` values per clip.
    """
    waveforms = [clip["array"] for clip in examples["audio"]]
    max_samples = 16000 * 30
    return feature_extractor(
        waveforms,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=max_samples,
        truncation=True,
    )
# Label every clip with the lyrics detector, then turn the raw audio into
# model inputs and expose the label under the column name "label".
ds = ds.map(prepare_audio, with_indices=True)
ds = ds.map(preprocess_function, remove_columns="audio", batched=True)
ds = ds.rename_column("lyrics", "label")

# The transcription model is only needed for labelling; drop the reference.
del asr

# Seed the split (same seed as the shuffle) so the train/test partition is
# identical across runs instead of being re-drawn every time.
ds = ds.train_test_split(test_size=0.1, seed=42)
# Binary task: ``prepare_audio`` stores 0 (no lyrics) or 1 (lyrics) for each
# clip, so the classification head needs two outputs. With a single output
# the label 1 is out of range for the loss and arg-max can only predict 0.
id2label = {0: "No lyrics", 1: "Lyrics"}
label2id = {name: idx for idx, name in id2label.items()}
num_labels = len(id2label)

model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base",
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)
# Training configuration. The effective batch size is 1 * 128 through
# gradient accumulation; the checkpoint with the best accuracy is reloaded
# once training ends.
training_args = TrainingArguments(
    output_dir="lyrics_detect",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=128,
    per_device_eval_batch_size=1,
    num_train_epochs=10,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
    # Target repository on the Hub. It has to be configured here: the first
    # positional argument of ``Trainer.push_to_hub`` is the commit message,
    # not the repository id.
    hub_model_id="WaveGenAI/lyrics-detection",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

trainer.train()

# Upload the final model to ``hub_model_id`` with the default commit message.
trainer.push_to_hub()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment