Jupyter notebook for creating a Whisper model without number tokens. Existing models: https://huggingface.co/collections/waveletdeboshir/whisper-without-numbers-67004c5d7bf9e1a99e373d54
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/waveletdeboshir/244b62ab800bc32f21a6bbb366cb81ba/removenumberswhisper.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yjhLPsMyLfKA"
},
"source": [
"# Remove number tokens from Whisper model and tokenizer\n",
"\n",
"My library versions:\n",
"* transformers 4.46.3\n",
"* torch 2.4.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "08cOSlseLfKC"
},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"HF_HUB_CACHE\"] = \"./models/\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5ToUGeRtLfKD"
},
"outputs": [],
"source": [
"import torch\n",
"from transformers import WhisperProcessor, WhisperTokenizer, WhisperForConditionalGeneration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZmWeSMCeLfKD"
},
"outputs": [],
"source": [
"# Whisper size: tiny, base, small, medium, large-v2, large-v3, large-v3-turbo\n",
"size = \"large-v3\"\n",
"new_name = \"no-numbers\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "r75Xw4BWLfKD",
"outputId": "30e65a13-75a4-4db6-a86f-8ca1d781f86d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cloning into 'whisper-large-v3'...\n",
"remote: Enumerating objects: 78, done.\u001b[K\n",
"remote: Counting objects: 100% (38/38), done.\u001b[K\n",
"remote: Compressing objects: 100% (38/38), done.\u001b[K\n",
"remote: Total 78 (delta 21), reused 0 (delta 0), pack-reused 40 (from 1)\u001b[K\n",
"Unpacking objects: 100% (78/78), 1.21 MiB | 3.32 MiB/s, done.\n",
"Filtering content: 100% (7/7), 11.00 GiB | 5.00 MiB/s, done.\n"
]
}
],
"source": [
"!mkdir models\n",
"!cd models && git clone https://huggingface.co/openai/whisper-{size}"
]
},
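{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative download (my addition, a sketch): huggingface_hub fetches only the\n",
"# current files, skipping the git history that `git clone` pulls in.\n",
"# from huggingface_hub import snapshot_download\n",
"# snapshot_download(f\"openai/whisper-{size}\", local_dir=f\"./models/whisper-{size}\")"
]
},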
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mURtVpcALfKE"
},
"outputs": [],
"source": [
"# Load the original model and tokenizer\n",
"tokenizer = WhisperTokenizer.from_pretrained(f\"./models/whisper-{size}\")\n",
"model = WhisperForConditionalGeneration.from_pretrained(f\"./models/whisper-{size}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RhbCH5tsLfKE"
},
"source": [
"# Find tokens with numbers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5g-dJg8nLfKE"
},
"outputs": [],
"source": [
"# Find token indices whose decoded text contains a digit\n",
"number_tokens = [\n",
"    i\n",
"    for i in range(tokenizer.vocab_size)\n",
"    if any(c in \"0123456789\" for c in tokenizer.decode([i], add_special_tokens=False))\n",
"]"
]
},
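{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (my addition): peek at a few flagged tokens to confirm they\n",
"# really carry digits before anything is removed.\n",
"[tokenizer.decode([i], add_special_tokens=False) for i in number_tokens[:10]]"
]
},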
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ITBRITeXLfKE",
"outputId": "4779f787-b444-4ccd-f57b-afa7e7441303"
},
"outputs": [
{
"data": {
"text/plain": [
"426"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(number_tokens)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AKlWluOuLfKE"
},
"outputs": [],
"source": [
"# If you also want to remove Roman numerals (except I, V, X, which are ordinary letters)\n",
"for roman in [\"II\", \"III\", \"IV\", \"VI\", \"VII\", \"VIII\", \"IX\", \"XI\", \"XII\", \"XIII\", \"XIV\", \"XV\", \"XVI\", \"XVII\", \"XVIII\", \"XIX\", \"XX\"]:\n",
"    t = tokenizer.encode(roman, add_special_tokens=False)\n",
"    if len(t) == 1:\n",
"        number_tokens.append(t[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6b5gqg7jLfKE",
"outputId": "6c097469-1751-4b14-b4db-6a69dde96ed3"
},
"outputs": [
{
"data": {
"text/plain": [
"431"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(number_tokens)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "D4JnhX8dLfKF",
"outputId": "7095ea0c-882a-4e88-d3b6-ec1c840d74c1"
},
"outputs": [
{
"data": {
"text/plain": [
"51866"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.proj_out.out_features"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aF7RbEnfLfKF"
},
"outputs": [],
"source": [
"# Token indices to keep (a set makes the membership test O(1))\n",
"number_token_set = set(number_tokens)\n",
"kept_ids = [n for n in range(model.proj_out.out_features) if n not in number_token_set]"
]
},
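{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Safety check (my addition): the special tokens remapped below must all survive,\n",
"# otherwise kept_ids.index(...) would raise ValueError.\n",
"kept_set = set(kept_ids)\n",
"for t in [\"<|endoftext|>\", \"<|startoftranscript|>\", \"<|transcribe|>\", \"<|notimestamps|>\"]:\n",
"    assert tokenizer.convert_tokens_to_ids(t) in kept_set, t"
]
},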
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ONbm310bLfKF",
"outputId": "783960a0-719b-4daa-b560-75714a08bd0a"
},
"outputs": [
{
"data": {
"text/plain": [
"51435"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(kept_ids)"
]
},
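{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional (my addition): the cells below call kept_ids.index(...), which scans\n",
"# the whole list on every lookup. This dict is an equivalent O(1) mapping you\n",
"# could substitute if the remapping feels slow.\n",
"old2new = {old_id: new_id for new_id, old_id in enumerate(kept_ids)}"
]
},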
{
"cell_type": "markdown",
"metadata": {
"id": "SsOB6OsYLfKF"
},
"source": [
"# Update model weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-Df70724LfKF"
},
"outputs": [],
"source": [
"import copy\n",
"new_model = copy.deepcopy(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "htCG5SjsLfKF"
},
"outputs": [],
"source": [
"new_size = len(kept_ids)\n",
"\n",
"# New embedding layer\n",
"\n",
"endoftext_idx = tokenizer.convert_tokens_to_ids(\"<|endoftext|>\")\n",
"new_emb = torch.nn.Embedding(\n",
"    new_size,\n",
"    model.model.decoder.embed_tokens.embedding_dim,\n",
"    padding_idx=kept_ids.index(endoftext_idx)  # new idx of <|endoftext|> token\n",
")\n",
"\n",
"# New proj_out layer\n",
"new_head = torch.nn.Linear(\n",
"    in_features=model.proj_out.in_features,\n",
"    out_features=new_size,\n",
"    bias=False\n",
")\n",
"\n",
"# Copy weights of kept tokens to their new positions\n",
"for new_id, old_id in enumerate(kept_ids):\n",
"    new_emb.weight.data[new_id] = model.model.decoder.embed_tokens.weight.data[old_id]\n",
"    new_head.weight.data[new_id] = model.proj_out.weight.data[old_id]\n",
"\n",
"# Swap the layers into the model\n",
"new_model.model.decoder.embed_tokens = new_emb\n",
"new_model.proj_out = new_head"
]
},
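{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Verification (my addition): remapped rows must equal the original rows.\n",
"# A vectorized equivalent of the loop above would be weight.data[kept_ids].clone().\n",
"assert torch.equal(new_emb.weight.data[0], model.model.decoder.embed_tokens.weight.data[kept_ids[0]])\n",
"assert torch.equal(new_head.weight.data[-1], model.proj_out.weight.data[kept_ids[-1]])"
]
},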
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "evfARbfBLfKG"
},
"outputs": [],
"source": [
"# Update model config: new vocab size, name and remapped special token ids\n",
"\n",
"new_model.config.__dict__['vocab_size'] = new_size\n",
"new_model.config.__dict__['_name_or_path'] = f'waveletdeboshir/whisper-{size}-{new_name}'\n",
"\n",
"new_model.config.__dict__[\"bos_token_id\"] = kept_ids.index(model.config.__dict__[\"bos_token_id\"])\n",
"new_model.config.__dict__[\"decoder_start_token_id\"] = kept_ids.index(model.config.__dict__[\"decoder_start_token_id\"])\n",
"new_model.config.__dict__[\"eos_token_id\"] = kept_ids.index(model.config.__dict__[\"eos_token_id\"])\n",
"new_model.config.__dict__[\"pad_token_id\"] = kept_ids.index(model.config.__dict__[\"pad_token_id\"])\n",
"new_model.config.__dict__[\"suppress_tokens\"] = []\n",
"new_model.config.__dict__[\"forced_decoder_ids\"] = [\n",
"    [\n",
"        1,\n",
"        kept_ids.index(tokenizer.convert_tokens_to_ids(\"<|en|>\"))  # language\n",
"    ],\n",
"    [\n",
"        2,\n",
"        kept_ids.index(tokenizer.convert_tokens_to_ids(\"<|transcribe|>\"))  # task\n",
"    ],\n",
"    [\n",
"        3,\n",
"        kept_ids.index(tokenizer.convert_tokens_to_ids(\"<|notimestamps|>\"))  # no timestamps\n",
"    ]\n",
"]\n",
"\n",
"beg_sup = []\n",
"for t in model.config.__dict__['begin_suppress_tokens']:\n",
"    if t in kept_ids:\n",
"        beg_sup.append(kept_ids.index(t))\n",
"new_model.config.__dict__['begin_suppress_tokens'] = beg_sup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5OdE5pxzLfKG"
},
"outputs": [],
"source": [
"# Update generation config the same way\n",
"\n",
"beg_sup = []\n",
"for t in model.generation_config.__dict__['begin_suppress_tokens']:\n",
"    if t in kept_ids:\n",
"        beg_sup.append(kept_ids.index(t))\n",
"new_model.generation_config.__dict__['begin_suppress_tokens'] = beg_sup\n",
"\n",
"new_model.generation_config.__dict__[\"bos_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"bos_token_id\"])\n",
"new_model.generation_config.__dict__[\"decoder_start_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"decoder_start_token_id\"])\n",
"new_model.generation_config.__dict__[\"eos_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"eos_token_id\"])\n",
"new_model.generation_config.__dict__[\"forced_decoder_ids\"] = [\n",
"    [\n",
"        1,\n",
"        None\n",
"    ],\n",
"    [\n",
"        2,\n",
"        kept_ids.index(tokenizer.convert_tokens_to_ids(\"<|transcribe|>\"))\n",
"    ]\n",
"]\n",
"\n",
"new_lang_to_id = {}\n",
"for key, value in model.generation_config.__dict__[\"lang_to_id\"].items():\n",
"    if value in kept_ids:\n",
"        new_lang_to_id[key] = kept_ids.index(value)\n",
"new_model.generation_config.__dict__[\"lang_to_id\"] = new_lang_to_id\n",
"\n",
"new_model.generation_config.__dict__[\"no_timestamps_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"no_timestamps_token_id\"])\n",
"new_model.generation_config.__dict__[\"pad_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"pad_token_id\"])\n",
"new_model.generation_config.__dict__[\"prev_sot_token_id\"] = kept_ids.index(model.generation_config.__dict__[\"prev_sot_token_id\"])\n",
"new_model.generation_config.__dict__[\"suppress_tokens\"] = []\n",
"new_model.generation_config.__dict__[\"task_to_id\"] = {\n",
"    key: kept_ids.index(value) for key, value in model.generation_config.__dict__[\"task_to_id\"].items()\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P5JvnUMKLfKG",
"outputId": "28d8254b-9909-4858-e2e7-90b074550d6a"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/daryavozdaeva/Work/anaconda3/envs/torch-env/lib/python3.10/site-packages/transformers/modeling_utils.py:2817: UserWarning: Moving the following attributes in the config to the generation config: {'max_length': 448, 'suppress_tokens': [], 'begin_suppress_tokens': [210, 49826]}. You are seeing this warning because you've set generation parameters in the model config, as opposed to in the generation config.\n",
"  warnings.warn(\n"
]
}
],
"source": [
"new_model.save_pretrained(f\"models/whisper-{size}-{new_name}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ul1_b4lLfKG"
},
"source": [
"# Change tokenizer\n",
"\n",
"First copy all the original tokenizer files into a separate folder, `models/tokenizer`.\n",
"\n",
"Then create a new folder where the changed tokenizer will be saved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BLu9uAUlLfKG"
},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uit2dCmCLfKG"
},
"outputs": [],
"source": [
"target_folder = \"tokenizer-nonumbers\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZGnu62MJLfKG",
"outputId": "3179787b-2b37-4fa8-f492-14e3937aee91"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mkdir: ./models/tokenizer-nonumbers: File exists\n",
"mkdir: ./models/tokenizer: File exists\n"
]
}
],
"source": [
"!mkdir ./models/{target_folder}\n",
"\n",
"!mkdir ./models/tokenizer\n",
"!cp ./models/whisper-{size}/added_tokens.json ./models/tokenizer/\n",
"!cp ./models/whisper-{size}/merges.txt ./models/tokenizer/\n",
"!cp ./models/whisper-{size}/special_tokens_map.json ./models/tokenizer/\n",
"!cp ./models/whisper-{size}/tokenizer.json ./models/tokenizer/\n",
"!cp ./models/whisper-{size}/tokenizer_config.json ./models/tokenizer/\n",
"!cp ./models/whisper-{size}/vocab.json ./models/tokenizer/"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1-qpKVe0LfKH"
},
"source": [
"Now we remap the token ids in every tokenizer file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IYkS5hhALfKH"
},
"outputs": [],
"source": [
"# Added tokens\n",
"with open(\"./models/tokenizer/added_tokens.json\", \"r\") as f:\n",
"    added_tokens = json.load(f)\n",
"\n",
"ch_added_tokens = {}\n",
"for key, value in added_tokens.items():\n",
"    if value in kept_ids:\n",
"        ch_added_tokens[key] = kept_ids.index(value)\n",
"\n",
"with open(f\"./models/{target_folder}/added_tokens.json\", \"w\") as f:\n",
"    json.dump(ch_added_tokens, f, indent=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8VySManzLfKH"
},
"outputs": [],
"source": [
"# Special tokens map\n",
"with open(\"./models/tokenizer/special_tokens_map.json\", \"r\") as f:\n",
"    special_tokens_map = json.load(f)\n",
"\n",
"special_tokens_map[\"additional_special_tokens\"] = [\"<|endoftext|>\"] + list(ch_added_tokens.keys())\n",
"with open(f\"./models/{target_folder}/special_tokens_map.json\", \"w\") as f:\n",
"    json.dump(special_tokens_map, f, indent=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "83HypuxfLfKH"
},
"outputs": [],
"source": [
"# Tokenizer config\n",
"with open(\"./models/tokenizer/tokenizer_config.json\", \"r\") as f:\n",
"    tok_config = json.load(f)\n",
"\n",
"ch_added_tokens_decoder = {}\n",
"for key, value in tok_config[\"added_tokens_decoder\"].items():\n",
"    if int(key) in kept_ids:\n",
"        ch_added_tokens_decoder[str(kept_ids.index(int(key)))] = value\n",
"\n",
"tok_config[\"added_tokens_decoder\"] = ch_added_tokens_decoder\n",
"tok_config[\"additional_special_tokens\"] = [\"<|endoftext|>\"] + list(ch_added_tokens.keys())\n",
"\n",
"with open(f\"./models/{target_folder}/tokenizer_config.json\", \"w\") as f:\n",
"    json.dump(tok_config, f, indent=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SUuHp_gHLfKH"
},
"outputs": [],
"source": [
"# Tokenizer\n",
"with open(\"./models/tokenizer/tokenizer.json\", \"r\") as f:\n",
"    tok = json.load(f)\n",
"\n",
"# change added tokens\n",
"ch_added_tokens = []\n",
"for t in tok[\"added_tokens\"]:\n",
"    if t[\"id\"] in kept_ids:\n",
"        t[\"id\"] = kept_ids.index(t[\"id\"])\n",
"        ch_added_tokens.append(t)\n",
"\n",
"tok[\"added_tokens\"] = ch_added_tokens\n",
"\n",
"# change vocab\n",
"ch_vocab = {}\n",
"for key, value in tok[\"model\"][\"vocab\"].items():\n",
"    if value in kept_ids:\n",
"        ch_vocab[key] = kept_ids.index(value)\n",
"\n",
"tok[\"model\"][\"vocab\"] = ch_vocab\n",
"\n",
"# change post processor (keep only surviving special tokens, with remapped ids)\n",
"ch_post = {}\n",
"for key, value in tok[\"post_processor\"][\"special_tokens\"].items():\n",
"    if value[\"ids\"][0] in kept_ids:\n",
"        value[\"ids\"][0] = kept_ids.index(value[\"ids\"][0])\n",
"        ch_post[key] = value\n",
"tok[\"post_processor\"][\"special_tokens\"] = ch_post\n",
"\n",
"with open(f\"./models/{target_folder}/tokenizer.json\", \"w\") as f:\n",
"    json.dump(tok, f, indent=4, ensure_ascii=True)"
]
},
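{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Consistency check (my addition): every remapped vocab id must fall inside\n",
"# the new vocab range [0, len(kept_ids)).\n",
"assert max(ch_vocab.values()) < len(kept_ids)"
]
},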
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QlWeJHvzLfKH"
},
"outputs": [],
"source": [
"# Vocab\n",
"with open(f\"./models/{target_folder}/vocab.json\", \"w\") as f:\n",
"    json.dump(ch_vocab, f, indent=4, ensure_ascii=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GGukXC4NLfKH"
},
"source": [
"Filter the merges file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y2Gr3szzLfKI"
},
"outputs": [],
"source": [
"with open(\"./models/tokenizer/merges.txt\", \"r\") as f:\n",
"    merges = f.read().split(\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fvDAPljJLfKI"
},
"outputs": [],
"source": [
"# Keep a merge if both halves are still producible and in the new vocab,\n",
"# or if the merged result itself is still a vocab token.\n",
"# merges[1:-1] skips the version header line and the trailing empty line.\n",
"not_found = []\n",
"not_found_merged_tokens = []\n",
"found = []\n",
"\n",
"for merge in merges[1:-1]:\n",
"    m = merge.split()\n",
"    halves_broken = (\n",
"        m[0] not in ch_vocab\n",
"        or m[1] not in ch_vocab\n",
"        or m[0] in not_found_merged_tokens\n",
"        or m[1] in not_found_merged_tokens\n",
"    )\n",
"    if halves_broken and (m[0] + m[1]) not in ch_vocab:\n",
"        not_found.append(merge)\n",
"        not_found_merged_tokens.append(m[0] + m[1])\n",
"    else:\n",
"        found.append(merge)"
]
},
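{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Peek (my addition): the dropped merges should lead to digit-bearing tokens,\n",
"# directly or via another dropped merge.\n",
"not_found[:5]"
]
},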
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1H5HdL2CLfKI",
"outputId": "e23e4756-d05c-47af-dd70-9f73c9afb3d2"
},
"outputs": [
{
"data": {
"text/plain": [
"49583"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(found)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LX6cCI-WLfKI"
},
"outputs": [],
"source": [
"# Write the merges back, restoring the version header and a trailing newline\n",
"# (the slow tokenizer skips the first line and expects the file to end with \"\\n\")\n",
"with open(f\"./models/{target_folder}/merges.txt\", \"w\") as f:\n",
"    f.write(\"\\n\".join([merges[0]] + found) + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dpM8x4rtLfKI"
},
"outputs": [],
"source": [
"# Load the changed tokenizer from the folder\n",
"changed_tok = WhisperTokenizer.from_pretrained(f\"./models/{target_folder}/\", local_files_only=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0vyoXcWRLfKJ",
"outputId": "e55af0e0-0a48-4d5f-de4b-b4f93e6b3c97"
},
"outputs": [
{
"data": {
"text/plain": [
"[49827, 49933, 8458, 34531, 210, None, None, None, None, None, 49826]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Digits no longer have tokens, so each one encodes to None\n",
"changed_tok.encode(\"Текст 12345\")"
]
},
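{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Round-trip check (my addition): digit-free text should come back unchanged\n",
"# (possibly up to leading whitespace).\n",
"changed_tok.decode(changed_tok.encode(\"hello world\"), skip_special_tokens=True)"
]
},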
{
"cell_type": "markdown",
"metadata": {
"id": "7cBTn84zLfKJ"
},
"source": [
"# Try new model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_lSobjYKLfKJ"
},
"outputs": [],
"source": [
"# Copy the new tokenizer files into the model folder, plus the normalizer file\n",
"# and preprocessor config from the original model\n",
"!cp ./models/{target_folder}/* ./models/whisper-{size}-{new_name}/\n",
"!cp ./models/whisper-{size}/normalizer.json ./models/whisper-{size}-{new_name}/\n",
"!cp ./models/whisper-{size}/preprocessor_config.json ./models/whisper-{size}-{new_name}/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"referenced_widgets": [
"57f391a5689f452bad263e42a8622adf"
]
},
"id": "zt1X-QKWLfKJ",
"outputId": "3d78b56b-e24f-4d5d-805f-0dc1d68c37db"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57f391a5689f452bad263e42a8622adf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Load the new model, processor and tokenizer from the folder\n",
"\n",
"tokenizer = WhisperTokenizer.from_pretrained(f\"./models/whisper-{size}-{new_name}/\", local_files_only=True)\n",
"model = WhisperForConditionalGeneration.from_pretrained(f\"./models/whisper-{size}-{new_name}\", local_files_only=True)\n",
"preprocessor = WhisperProcessor.from_pretrained(f\"./models/whisper-{size}-{new_name}\", local_files_only=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jpz3nInpLfKK"
},
"source": [
"Check that everything works on a test file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KUIexZMiLfKK"
},
"outputs": [],
"source": [
"import torchaudio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UyQIo_2PLfKK"
},
"outputs": [],
"source": [
"wav, sr = torchaudio.load(\"test.wav\")\n",
"\n",
"# Whisper expects 16 kHz input\n",
"if sr != 16000:\n",
"    wav = torchaudio.functional.resample(wav, sr, 16000)\n",
"\n",
"processed = preprocessor(wav[0], sampling_rate=16000, return_tensors=\"pt\")\n",
"\n",
"predicted_ids = model.generate(processed.input_features, language=\"ru\", task=\"transcribe\")\n",
"\n",
"transcriptions = preprocessor.batch_decode(predicted_ids, skip_special_tokens=False)\n",
"\n",
"print(transcriptions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4YTber9mLfKK"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "torch-env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}