Code to fine-tune a lyrics detector
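"""Fine-tune facebook/wav2vec2-base as a binary lyrics detector.

Each track in lewtun/music_genres is labeled with an ASR heuristic
(does the audio contain lyrics?), wav2vec2 input features are extracted,
and an audio-classification head is trained with the HF Trainer.
"""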
import evaluate
import numpy as np
from audiotools import AudioSignal
from datasets import Audio, load_dataset
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    Trainer,
    TrainingArguments,
)
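# `pipelines.transcript` is a local module that is not included in this gist.
# If it is unavailable, the fallback below is a hypothetical sketch of the
# interface the script relies on (contain_lyrics(waveform) -> bool), here
# approximated by checking whether a small Whisper model transcribes any text.
try:
    from pipelines.transcript import TranscriptModel
except ImportError:
    from transformers import pipeline

    class TranscriptModel:
        def __init__(self):
            self._asr = pipeline(
                "automatic-speech-recognition", model="openai/whisper-tiny"
            )

        def contain_lyrics(self, waveform) -> bool:
            # Treat any non-empty transcription as evidence of lyrics.
            text = self._asr({"raw": waveform, "sampling_rate": 16_000})["text"]
            return len(text.strip()) > 0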
# Load the dataset and resample every clip to the 16 kHz expected by wav2vec2.
ds = load_dataset("lewtun/music_genres", split="train")
ds = ds.shuffle(seed=42)
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))

accuracy = evaluate.load("accuracy")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
asr = TranscriptModel()
def prepare_audio(row, idx):
    """Label one example: 1 if the ASR heuristic detects lyrics, else 0."""
    row["lyrics"] = 0

    try:
        audio, sr = row["audio"]["array"], row["audio"]["sampling_rate"]
        audio = AudioSignal(audio, sr)
    except Exception as e:
        print(f"Error processing audio at index {idx}: {e}")
        return row

    # Convert to mono and resample before running the ASR heuristic.
    audio = audio.to_mono()
    audio = audio.resample(16000)

    has_lyrics = asr.contain_lyrics(audio.numpy()[0, 0])
    row["lyrics"] = int(has_lyrics)

    return row
def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
def preprocess_function(examples):
    # Extract wav2vec2 input features, truncating each clip to 30 s at 16 kHz.
    audio_arrays = [x["array"] for x in examples["audio"]]
    inputs = feature_extractor(
        audio_arrays,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=16000 * 30,
        truncation=True,
    )
    return inputs
ds = ds.map(prepare_audio, with_indices=True)
ds = ds.map(preprocess_function, remove_columns="audio", batched=True)
ds = ds.rename_column("lyrics", "label")

del asr  # free the ASR model before training

ds = ds.train_test_split(test_size=0.1)

# Binary task: two labels, matching the 0/1 values produced by prepare_audio.
label2id = {"no_lyrics": 0, "lyrics": 1}
id2label = {0: "no_lyrics", 1: "lyrics"}
num_labels = 2
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base",
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)
training_args = TrainingArguments(
    output_dir="lyrics_detect",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=128,  # effective batch size of 128
    per_device_eval_batch_size=1,
    num_train_epochs=10,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
    hub_model_id="WaveGenAI/lyrics-detection",  # target Hub repo
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

trainer.train()

# Trainer.push_to_hub takes a commit message, not a repo id; the target repo
# is set via hub_model_id above.
trainer.push_to_hub()
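# Quick sanity check after training, as a minimal sketch: "song.wav" is a
# hypothetical local file, not part of the original gist.
from transformers import pipeline

classifier = pipeline(
    "audio-classification", model=model, feature_extractor=feature_extractor
)
print(classifier("song.wav"))  # e.g. [{"label": "lyrics", "score": ...}, ...]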