AdaptiveSoftmax and Adaptive Embedding example
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Break out `AdaptiveLogSoftmaxWithLoss` into two parts:\n",
    "\n",
    "* Forward pass that returns full log-probability predictions\n",
    "* Loss calculation (with the predictions in the right shape, the usual loss components flow through unchanged)\n",
    "\n",
    "Then add an adaptive embedding so that you can \"tie weights\" between the embedding and the softmax. The stock PyTorch module is sketched in the next cell for comparison."
   ]
  },
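  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a minimal sketch of the stock `nn.AdaptiveLogSoftmaxWithLoss` that gets split apart below: its `forward` fuses prediction and loss into one call, which is awkward to drop into fastai's `SequentialRNN`. All sizes here are toy values for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# toy sizes: 400-d hidden states, 1000-token vocab, 100-token shortlist, two tail clusters\n",
    "asm = nn.AdaptiveLogSoftmaxWithLoss(in_features=400, n_classes=1000, cutoffs=[100, 400], div_value=4.)\n",
    "hidden = torch.randn(32, 400)           # fake batch of hidden states\n",
    "target = torch.randint(0, 1000, (32,))  # fake targets\n",
    "\n",
    "out = asm(hidden, target)   # prediction and loss come back fused together\n",
    "print(out.output.shape, out.loss)\n",
    "\n",
    "log_probs = asm.log_prob(hidden)  # the full [32, 1000] log-probabilities\n",
    "print(log_probs.shape)"
   ]
  },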
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.text import *\n",
    "\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.IMDB_SAMPLE)\n",
    "data = TextLMDataBunch.from_csv(path, 'texts.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ProjectedAdaptiveEmbedding(nn.Module):\n",
    "    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):\n",
    "        super().__init__()\n",
    "        self.n_token,self.d_embed,self.d_proj,self.div_val = n_token,d_embed,d_proj,div_val\n",
    "        self.cutoffs = cutoffs + [n_token]\n",
    "        self.cutoff_ends = [0] + self.cutoffs\n",
    "        self.emb_scale = d_proj ** 0.5\n",
    "\n",
    "        self.emb_layers = nn.ModuleList()\n",
    "        self.emb_projs = nn.ParameterList()\n",
    "\n",
    "        if div_val == 1:\n",
    "            self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax>0))\n",
    "            if d_proj != d_embed: self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))\n",
    "        else:\n",
    "            # one embedding table per cluster, each div_val times narrower than the last\n",
    "            for i in range(len(self.cutoffs)):\n",
    "                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]\n",
    "                d_emb_i = int(d_embed // (div_val ** i))\n",
    "                self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))\n",
    "                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))\n",
    "        ## init the projections - torch.Tensor() is uninitialized memory, so without this you get nan\n",
    "        for x in self.emb_projs:\n",
    "            nn.init.normal_(x, 0, 0.2)\n",
    "\n",
    "    def forward(self, inp):\n",
    "        if self.div_val == 1:\n",
    "            embed = self.emb_layers[0](inp)\n",
    "            if self.d_proj != self.d_embed: embed = F.linear(embed, self.emb_projs[0])\n",
    "        else:\n",
    "            inp_flat = inp.view(-1)\n",
    "            emb_flat = self.emb_layers[0].weight.new_zeros((inp_flat.size(0), self.d_proj))\n",
    "            for i in range(len(self.cutoffs)):\n",
    "                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]\n",
    "\n",
    "                # positions holding tokens from this cluster; squeeze(-1) keeps the\n",
    "                # index 1-d even when there is only a single match\n",
    "                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)\n",
    "                indices_i = mask_i.nonzero().squeeze(-1)\n",
    "                if indices_i.numel() == 0: continue\n",
    "\n",
    "                inp_i = inp_flat.index_select(0, indices_i) - l_idx\n",
    "                emb_i = self.emb_layers[i](inp_i)\n",
    "                emb_i = F.linear(emb_i, self.emb_projs[i])\n",
    "                emb_flat.index_copy_(0, indices_i, emb_i)\n",
    "\n",
    "            embed = emb_flat.view(*inp.size(), self.d_proj)\n",
    "        embed.mul_(self.emb_scale)\n",
    "\n",
    "        return embed"
   ]
  },
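  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick shape check on `ProjectedAdaptiveEmbedding` with made-up toy sizes: whichever cluster a token falls in, its embedding gets projected back up to `d_proj`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# toy sizes for illustration only\n",
    "toy_emb = ProjectedAdaptiveEmbedding(n_token=1000, d_embed=64, d_proj=64, cutoffs=[100, 400], div_val=4)\n",
    "toy_inp = torch.randint(0, 1000, (2, 10))  # [bs, bptt] of token ids\n",
    "print(toy_emb(toy_inp).shape)              # expect [2, 10, 64]"
   ]
  },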
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ProjectedAdaptiveLogSoftmax(nn.Module):\n",
    "    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1):\n",
    "        super().__init__()\n",
    "        self.n_token,self.d_embed,self.d_proj,self.div_val = n_token,d_embed,d_proj,div_val\n",
    "        self.cutoffs = cutoffs + [n_token]\n",
    "        self.cutoff_ends = [0] + self.cutoffs\n",
    "\n",
    "        self.shortlist_size = self.cutoffs[0]\n",
    "        self.n_clusters = len(self.cutoffs) - 1\n",
    "        self.head_size = self.shortlist_size + self.n_clusters\n",
    "\n",
    "        if self.n_clusters > 0:\n",
    "            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))\n",
    "            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))\n",
    "\n",
    "        self.out_layers = nn.ModuleList()\n",
    "        self.out_projs = nn.ParameterList()\n",
    "\n",
    "        if div_val == 1:\n",
    "            for i in range(len(self.cutoffs)):\n",
    "                if d_proj != d_embed: self.out_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))\n",
    "                else: self.out_projs.append(None)\n",
    "            self.out_layers.append(nn.Linear(d_embed, n_token))\n",
    "        else:\n",
    "            for i in range(len(self.cutoffs)):\n",
    "                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]\n",
    "                d_emb_i = int(d_embed // (div_val ** i))\n",
    "                self.out_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))\n",
    "                self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))\n",
    "\n",
    "        ## init the projections, skipping the None placeholders from the div_val==1 path\n",
    "        for x in self.out_projs:\n",
    "            if x is not None: nn.init.normal_(x, 0, 0.2)\n",
    "\n",
    "    def _compute_logit(self, hidden, weight, bias, proj):\n",
    "        if proj is None:\n",
    "            logit = F.linear(hidden, weight, bias=bias)\n",
    "        else:\n",
    "            # project the hidden state down to the cluster's width first\n",
    "            proj_hid = F.linear(hidden, proj.t().contiguous())\n",
    "            logit = F.linear(proj_hid, weight, bias=bias)\n",
    "        return logit\n",
    "\n",
    "    def forward(self, input):\n",
    "        if self.n_clusters == 0:\n",
    "            logit = self._compute_logit(input, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])\n",
    "            out = F.log_softmax(logit, dim=-1)\n",
    "        else:\n",
    "            # construct weights and biases for each cluster\n",
    "            weights, biases = [], []\n",
    "            for i in range(len(self.cutoffs)):\n",
    "                if self.div_val == 1:\n",
    "                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]\n",
    "                    weight_i = self.out_layers[0].weight[l_idx:r_idx]\n",
    "                    bias_i = self.out_layers[0].bias[l_idx:r_idx]\n",
    "                else:\n",
    "                    weight_i = self.out_layers[i].weight\n",
    "                    bias_i = self.out_layers[i].bias\n",
    "\n",
    "                if i == 0:\n",
    "                    # add rows for the non-head cluster logits\n",
    "                    weight_i = torch.cat([weight_i, self.cluster_weight], dim=0)\n",
    "                    bias_i = torch.cat([bias_i, self.cluster_bias], dim=0)\n",
    "\n",
    "                weights.append(weight_i)\n",
    "                biases.append(bias_i)\n",
    "\n",
    "            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]\n",
    "            head_logit = self._compute_logit(input, head_weight, head_bias, head_proj)\n",
    "            head_logprob = F.log_softmax(head_logit, dim=-1)\n",
    "\n",
    "            out = input.new_empty((input.size(0), input.size(1), self.n_token))\n",
    "            out[:, :, :self.shortlist_size] = head_logprob[:, :, :self.shortlist_size]  # [bs,bptt,shortlist_size]\n",
    "\n",
    "            for i, (start_idx, stop_idx) in enumerate(zip(self.cutoff_ends, self.cutoff_ends[1:])):\n",
    "                if i > 0:\n",
    "                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]\n",
    "                    tail_logit_i = self._compute_logit(input, weight_i, bias_i, proj_i)\n",
    "                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=-1)\n",
    "                    # note: -i pairs tail i with the cluster logits in reverse order;\n",
    "                    # the pairing is arbitrary but consistent, so it trains fine\n",
    "                    logprob_i = head_logprob[:,:,-i].unsqueeze(-1) + tail_logprob_i\n",
    "                    out[:, :, start_idx:stop_idx] = logprob_i\n",
    "\n",
    "        return out  # log-probabilities over the full vocab"
   ]
  },
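  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A cheap sanity check on the decoder (toy sizes again): since each tail log-probability is offset by its cluster's head log-probability, `exp(out)` should sum to 1 over the full vocab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_dec = ProjectedAdaptiveLogSoftmax(n_token=1000, d_embed=64, d_proj=64, cutoffs=[100, 400], div_val=4)\n",
    "toy_hid = torch.randn(2, 10, 64)             # fake [bs, bptt, d_proj] hidden states\n",
    "toy_out = toy_dec(toy_hid)\n",
    "print(toy_out.shape, toy_out.exp().sum(-1))  # [2, 10, 1000]; every sum should be ~1.0"
   ]
  },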
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = language_model_learner(data, AWD_LSTM, drop_mult=0.5, pretrained=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_sz = len(learn.data.vocab.itos)\n",
    "cutoffs = [round(vocab_sz/15), 3*round(vocab_sz/15)]  # you can set these to any groupings you would like"
   ]
  },
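  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what those cutoffs mean in practice, the cell below prints the (start, end, size) of each bucket; the exact numbers depend on whatever vocab the IMDB sample produced."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shortlist plus two tail clusters\n",
    "ends = [0] + cutoffs + [vocab_sz]\n",
    "print([(lo, hi, hi - lo) for lo, hi in zip(ends, ends[1:])])"
   ]
  },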
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AdaptSoftMaxLayer(nn.Module):\n",
    "    \"Adaptive softmax decoder to put on top of an RNNCore module to create a language model.\"\n",
    "    def __init__(self, vocab_sz:int, n_hid:int, tie_encoder:nn.Module=None, cutoffs:List[int]=[], div_value=4.):\n",
    "        super().__init__()\n",
    "        self.decoder = ProjectedAdaptiveLogSoftmax(vocab_sz, n_hid, n_hid, cutoffs, div_val=div_value)\n",
    "        if tie_encoder:\n",
    "            for out_l, emb_l in zip(self.decoder.out_layers, tie_encoder.emb_layers):\n",
    "                out_l.weight = emb_l.weight\n",
    "            # assign into the ParameterList by index; merely rebinding the loop\n",
    "            # variable (out_p = emb_p) would not tie anything\n",
    "            for i, emb_p in enumerate(tie_encoder.emb_projs):\n",
    "                self.decoder.out_projs[i] = emb_p\n",
    "\n",
    "    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:\n",
    "        raw_outputs, outputs = input\n",
    "        decoded = self.decoder(outputs[-1])\n",
    "        return decoded, raw_outputs, outputs"
   ]
  },
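  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick check (with the same toy sizes as before) that the tying actually took: the decoder must reuse the embedding's `Parameter` objects, not copies of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_enc = ProjectedAdaptiveEmbedding(1000, 64, 64, [100, 400], div_val=4)\n",
    "toy_tied = AdaptSoftMaxLayer(1000, 64, tie_encoder=toy_enc, cutoffs=[100, 400])\n",
    "print(toy_tied.decoder.out_layers[0].weight is toy_enc.emb_layers[0].weight)  # True\n",
    "print(toy_tied.decoder.out_projs[0] is toy_enc.emb_projs[0])                  # True"
   ]
  },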
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# this has to be changed in fastai/text/models/awd_lstm.py until I get around to fixing it properly:\n",
    "# with this patch there is no embedding dropout, and if you only override the encoder,\n",
    "# the override will not be passed along to encoder_dp.\n",
    "        self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)\n",
    "-->>>   self.encoder_dp = self.encoder  # was: EmbeddingDropout(self.encoder, embed_p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc_layer = ProjectedAdaptiveEmbedding(vocab_sz, 400, 400, cutoffs, div_val=4, sample_softmax=False)\n",
    "#learn.model[0].encoder = to_device(enc_layer, defaults.device)\n",
    "learn.model[0].encoder_dp = to_device(enc_layer, defaults.device)\n",
    "\n",
    "decode_layer = AdaptSoftMaxLayer(vocab_sz, 400, cutoffs=cutoffs, tie_encoder=learn.model[0].encoder_dp)\n",
    "learn.model[1] = to_device(decode_layer, defaults.device)"
   ]
  },
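  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training, it is worth pushing one batch through the patched model to confirm the decoder returns full-vocab log-probabilities; a minimal sketch, assuming fastai v1's `DataBunch.one_batch` and the `reset` method on `SequentialRNN`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xb, yb = data.one_batch()\n",
    "learn.model.reset()\n",
    "decoded, raw_outs, outs = learn.model(xb.to(defaults.device))\n",
    "print(decoded.shape)  # expect [bs, bptt, vocab_sz]"
   ]
  },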
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SequentialRNN(\n",
       "  (0): AWD_LSTM(\n",
       "    (encoder): Embedding(8920, 400, padding_idx=1)\n",
       "    (encoder_dp): ProjectedAdaptiveEmbedding(\n",
       "      (emb_layers): ModuleList(\n",
       "        (0): Embedding(595, 400)\n",
       "        (1): Embedding(1190, 100)\n",
       "        (2): Embedding(7135, 25)\n",
       "      )\n",
       "      (emb_projs): ParameterList(\n",
       "          (0): Parameter containing: [torch.cuda.FloatTensor of size 400x400 (GPU 0)]\n",
       "          (1): Parameter containing: [torch.cuda.FloatTensor of size 400x100 (GPU 0)]\n",
       "          (2): Parameter containing: [torch.cuda.FloatTensor of size 400x25 (GPU 0)]\n",
       "      )\n",
       "    )\n",
       "    (rnns): ModuleList(\n",
       "      (0): WeightDropout(\n",
       "        (module): LSTM(400, 1150, batch_first=True)\n",
       "      )\n",
       "      (1): WeightDropout(\n",
       "        (module): LSTM(1150, 1150, batch_first=True)\n",
       "      )\n",
       "      (2): WeightDropout(\n",
       "        (module): LSTM(1150, 400, batch_first=True)\n",
       "      )\n",
       "    )\n",
       "    (input_dp): RNNDropout()\n",
       "    (hidden_dps): ModuleList(\n",
       "      (0): RNNDropout()\n",
       "      (1): RNNDropout()\n",
       "      (2): RNNDropout()\n",
       "    )\n",
       "  )\n",
       "  (1): AdaptSoftMaxLayer(\n",
       "    (decoder): ProjectedAdaptiveLogSoftmax(\n",
       "      (out_layers): ModuleList(\n",
       "        (0): Linear(in_features=400, out_features=595, bias=True)\n",
       "        (1): Linear(in_features=100, out_features=1190, bias=True)\n",
       "        (2): Linear(in_features=25, out_features=7135, bias=True)\n",
       "      )\n",
       "      (out_projs): ParameterList(\n",
       "          (0): Parameter containing: [torch.cuda.FloatTensor of size 400x400 (GPU 0)]\n",
       "          (1): Parameter containing: [torch.cuda.FloatTensor of size 400x100 (GPU 0)]\n",
       "          (2): Parameter containing: [torch.cuda.FloatTensor of size 400x25 (GPU 0)]\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:44 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>7.483484</th>\n",
       "    <th>6.249792</th>\n",
       "    <th>0.167092</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>6.734814</th>\n",
       "    <th>6.044311</th>\n",
       "    <th>0.183339</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>6.430354</th>\n",
       "    <th>5.918349</th>\n",
       "    <th>0.196476</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>6.253704</th>\n",
       "    <th>5.866490</th>\n",
       "    <th>0.203858</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>6.136080</th>\n",
       "    <th>5.829422</th>\n",
       "    <th>0.203141</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit(5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7 fasta.ai1 DEV",
   "language": "python",
   "name": "fastai1_dev"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
} |