Last active
October 10, 2020 16:02
-
-
Save joshfp/b62b76eae95e6863cb511997b5a63118 to your computer and use it in GitHub Desktop.
Fast.ai p1v1: class 4
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Full deep learning model " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Automatically reload edited Python modules before each cell runs (mode 2 = all modules).
%reload_ext autoreload
%autoreload 2
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from fastai import *\n", | |
| "from fastai.tabular import *\n", | |
| "from fastai.text import *\n", | |
| "from fastai.metrics import accuracy" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Size of the hold-out set: the last `valid_sz` rows of the data are used for validation.
valid_sz = 10000
# Root directory for the data files and saved models (~/data).
PATH = Path.home() / 'data'
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Load the full listings table from a feather file; the cells below expect it to contain the
# free-text 'title' column and the 'condition' target.
df = pd.read_feather(PATH/'listings-df')
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# 1. Input models" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 1.1. Tabular model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# The tabular model sees every column except the free-text title (handled by the NLP model).
df_tab = df.drop(columns='title')
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Continuous features (the real column names were anonymised as col1..col12).
# Keep the comma after 'col12': without it Python concatenates the adjacent string literals
# into one bogus column 'col12title_isnew_prob', and both 'col12' and 'title_isnew_prob'
# silently end up in cat_cols instead.
# NOTE(review): the printed model below has BatchNorm1d(12) for continuous inputs; confirm this
# list matches the columns the saved 'tabular-model' weights were trained with.
cont_cols = ['col1', 'col2', 'col3', 'col4', 'col5', 'col6',
             'col7', 'col8', 'col9', 'col10', 'col11', 'col12',
             'title_isnew_prob']
# Every remaining column except the 'condition' target is treated as categorical.
cat_cols = sorted(set(df_tab.columns) - set(cont_cols) - {'condition'})
# Validation set = the last valid_sz rows (the NLP data below is split the same way).
valid_idx = range(len(df)-valid_sz, len(df))
procs = [FillMissing, Categorify, Normalize]
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Tabular DataBunch: apply `procs` (FillMissing, Categorify, Normalize) to the categorical and
# continuous columns, hold out the rows in valid_idx, and use 'condition' as the label.
data_tab = (TabularList.from_df(df_tab, cat_cols, cont_cols, procs=procs, path=PATH)
           .split_by_idx(valid_idx)
           .label_from_df(cols='condition')
           .databunch())
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Rebuild the tabular learner (one 64-unit hidden layer) and load previously saved weights.
# The architecture must match the one the 'tabular-model' weights were saved with, or load() fails.
learn_tab = tabular_learner(data_tab, layers=[64], ps=[0.5], emb_drop=0.05, metrics=accuracy)
learn_tab.load('tabular-model');  # trailing ';' suppresses the returned object's repr
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "TabularModel(\n", | |
| " (embeds): ModuleList(\n", | |
| " (0): Embedding(4, 3)\n", | |
| " (1): Embedding(10492, 50)\n", | |
| " (2): Embedding(3, 2)\n", | |
| " (3): Embedding(8, 5)\n", | |
| " (4): Embedding(1461, 50)\n", | |
| " (5): Embedding(286, 50)\n", | |
| " (6): Embedding(3481, 50)\n", | |
| " (7): Embedding(304, 50)\n", | |
| " (8): Embedding(570, 50)\n", | |
| " (9): Embedding(30, 16)\n", | |
| " (10): Embedding(26, 14)\n", | |
| " (11): Embedding(300, 50)\n", | |
| " (12): Embedding(33283, 50)\n", | |
| " (13): Embedding(5, 3)\n", | |
| " (14): Embedding(5, 3)\n", | |
| " (15): Embedding(3, 2)\n", | |
| " (16): Embedding(3, 2)\n", | |
| " )\n", | |
| " (emb_drop): Dropout(p=0.05)\n", | |
| " (bn_cont): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
| " (layers): Sequential(\n", | |
| " (0): Linear(in_features=462, out_features=64, bias=True)\n", | |
| " (1): ReLU(inplace)\n", | |
| " )\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Use the trained tabular model as a feature extractor: drop the last three modules of its head
# so it ends at the Linear(462 -> 64) + ReLU shown in the output below.
# NOTE: not idempotent -- re-running this cell strips three more modules; re-run the learner cell first.
learn_tab.model.layers = learn_tab.model.layers[:-3]
learn_tab.model
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 1.2. NLP Model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# The text model only needs the title text and the 'condition' target.
df_nlp = df[['title', 'condition']]
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Load the vocabulary pickled as 'itos' so token ids line up with the saved NLP model weights.
# The `with` block closes the file handle (a bare open() inside pickle.load leaks it).
# NOTE: pickle can execute arbitrary code on load -- only load files you produced yourself.
with open(PATH/'itos', 'rb') as itos_file:
    vocab = pickle.load(itos_file)
# Same split as the tabular data: the last valid_sz rows form the validation set.
data_nlp = TextClasDataBunch.from_df(PATH, df_nlp[:-valid_sz], df_nlp[-valid_sz:],
                                     tokenizer=Tokenizer(lang='es'),
                                     vocab=vocab, text_cols='title', label_cols='condition')
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Text classifier learner (drop_mult scales all of its dropout probabilities), then load the
# fine-tuned classifier weights saved as 'nlp-final'.
learn_nlp = text_classifier_learner(data_nlp, drop_mult=0.5)
learn_nlp.load('nlp-final');  # trailing ';' suppresses the returned object's repr
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "SequentialRNN(\n", | |
| " (0): MultiBatchRNNCore(\n", | |
| " (encoder): Embedding(22847, 400, padding_idx=1)\n", | |
| " (encoder_dp): EmbeddingDropout(\n", | |
| " (emb): Embedding(22847, 400, padding_idx=1)\n", | |
| " )\n", | |
| " (rnns): ModuleList(\n", | |
| " (0): WeightDropout(\n", | |
| " (module): LSTM(400, 1150)\n", | |
| " )\n", | |
| " (1): WeightDropout(\n", | |
| " (module): LSTM(1150, 1150)\n", | |
| " )\n", | |
| " (2): WeightDropout(\n", | |
| " (module): LSTM(1150, 400)\n", | |
| " )\n", | |
| " )\n", | |
| " (input_dp): RNNDropout()\n", | |
| " (hidden_dps): ModuleList(\n", | |
| " (0): RNNDropout()\n", | |
| " (1): RNNDropout()\n", | |
| " (2): RNNDropout()\n", | |
| " )\n", | |
| " )\n", | |
| " (1): PoolingLinearClassifier(\n", | |
| " (layers): Sequential(\n", | |
| " (0): BatchNorm1d(1200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
| " (1): Dropout(p=0.2)\n", | |
| " (2): Linear(in_features=1200, out_features=50, bias=True)\n", | |
| " (3): ReLU(inplace)\n", | |
| " )\n", | |
| " )\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Use the text classifier as a feature extractor: drop the last three modules of its head so it
# ends at Linear(1200 -> 50) + ReLU, as printed below.
# NOTE: not idempotent -- re-running this cell strips three more modules; re-run the learner cell first.
learn_nlp.model[-1].layers = learn_nlp.model[-1].layers[:-3] 
learn_nlp.model
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# 2. Concat model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class ConcatDataset(Dataset):
    """Pair two input datasets row-by-row with a shared target.

    Item ``i`` is ``((x1[i], x2[i]), y[i])``, so ``x1``, ``x2`` and ``y`` must describe
    the same rows in the same order.

    Raises:
        ValueError: if the three sources do not have the same length.
    """
    def __init__(self, x1, x2, y):
        # A length mismatch means the two pipelines are misaligned. Without this check, extra
        # rows are silently ignored (len() follows y) or indexing fails later mid-epoch.
        if not (len(x1) == len(x2) == len(y)):
            raise ValueError(f'ConcatDataset inputs differ in length: '
                             f'x1={len(x1)}, x2={len(x2)}, y={len(y)}')
        self.x1, self.x2, self.y = x1, x2, y

    def __len__(self): return len(self.y)

    def __getitem__(self, i): return (self.x1[i], self.x2[i]), self.y[i]
| "\n", | |
# Pair tabular row i with text row i in each split. Both pipelines hold out the last valid_sz
# rows; the targets are taken from the tabular side.
train_ds, valid_ds = [ConcatDataset(tab.x, txt.x, tab.y)
                      for tab, txt in ((data_tab.train_ds, data_nlp.train_ds),
                                       (data_tab.valid_ds, data_nlp.valid_ds))]
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def my_collate(batch):
    """Collate ``((tabular_item, text_item), label)`` samples into one model batch.

    Returns ``((cats, conts), text), labels``:
      - ``cats`` and ``conts`` are the stacked tabular tensors.
      - ``text`` and ``labels`` come from ``pad_collate``, which pads the token
        sequences with index 1 at the front.
    """
    inputs, targets = zip(*batch)
    tab_items, txt_items = zip(*inputs)
    # Each tabular item unwraps to [cats, conts]; regroup into one column per part, then stack.
    columns = list(zip(*to_data(tab_items)))
    tab_batch = (torch.stack(columns[0]), torch.stack(columns[1]))
    txt_batch, labels = pad_collate(list(zip(txt_items, targets)), pad_idx=1, pad_first=True)
    return (tab_batch, txt_batch), labels
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
bs = 64
# Order text rows by token count (the sampler key is the length of each title's token data) so
# that batches need little padding. SortishSampler keeps some randomness for training;
# SortSampler sorts the validation set fully.
train_sampler = SortishSampler(data_nlp.train_ds.x, key=lambda t: len(data_nlp.train_ds[t][0].data), bs=bs//2)
valid_sampler = SortSampler(data_nlp.valid_ds.x, key=lambda t: len(data_nlp.valid_ds[t][0].data))
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Training batches use bs//2 (32, matching the shapes printed below); validation uses bs.
# The DataLoaders get no collate_fn of their own; my_collate is passed to DataBunch instead.
train_dl = DataLoader(train_ds, bs//2, sampler=train_sampler)
valid_dl = DataLoader(valid_ds, bs, sampler=valid_sampler)
data = DataBunch(train_dl, valid_dl, device=defaults.device, collate_fn=my_collate, path=PATH)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Shape tabular batch (cats/cont): torch.Size([32, 17]) / torch.Size([32, 12])\n", | |
| "Shape nlp batch: torch.Size([42, 32])\n", | |
| "Shape dependent var: torch.Size([32])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# Sanity check: pull one training batch and show the shape of every input and of the target.
# The text batch is (seq_len, batch) -- see the output below.
(x1,x2),y = next(iter(data.train_dl))
print(f'Shape tabular batch (cats/cont): {x1[0].shape} / {x1[1].shape}')
print(f'Shape nlp batch: {x2.shape}')
print(f'Shape dependent var: {y.shape}')
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class ConcatModel(nn.Module):
    """Fuse a tabular model and a text model through a small MLP head.

    Args:
        mod_tab: tabular body, called as ``mod_tab(cats, conts)``.
        mod_nlp: text model, called on the token tensor; element ``[0]`` of its
            result is used as the text features.
        layers: head sizes ``[n_in, hidden..., n_out]``. ``n_in`` must equal the sum
            of the two bodies' output widths.
        drops: one dropout probability per head block, i.e. ``len(layers) - 1`` values.

    Raises:
        ValueError: if ``layers`` has fewer than two sizes or ``drops`` has the wrong length.
    """
    def __init__(self, mod_tab, mod_nlp, layers, drops):
        super().__init__()
        self.mod_tab = mod_tab
        self.mod_nlp = mod_nlp
        n_blocks = len(layers) - 1
        # zip() below would silently truncate to the shortest list, dropping head blocks
        # (possibly the output layer), so validate the sizes explicitly.
        if n_blocks < 1:
            raise ValueError(f'layers needs at least an input and an output size, got {layers}')
        if len(drops) != n_blocks:
            raise ValueError(f'expected {n_blocks} dropout values for layers={layers}, '
                             f'got {len(drops)}')
        # ReLU after every hidden block; no activation on the final (logit) block.
        activs = [nn.ReLU(inplace=True)] * (n_blocks - 1) + [None]
        lst_layers = []
        for n_in, n_out, p, actn in zip(layers[:-1], layers[1:], drops, activs):
            lst_layers += bn_drop_lin(n_in, n_out, p=p, actn=actn)
        self.layers = nn.Sequential(*lst_layers)

    def forward(self, *x):
        # x == ((cats, conts), text), the batch layout produced by my_collate.
        x_tab = self.mod_tab(*x[0])
        x_nlp = self.mod_nlp(x[1])[0]
        x = torch.cat([x_tab, x_nlp], dim=1)
        return self.layers(x)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "ConcatModel(\n", | |
| " (mod_tab): TabularModel(\n", | |
| " (embeds): ModuleList(\n", | |
| " (0): Embedding(4, 3)\n", | |
| " (1): Embedding(10492, 50)\n", | |
| " (2): Embedding(3, 2)\n", | |
| " (3): Embedding(8, 5)\n", | |
| " (4): Embedding(1461, 50)\n", | |
| " (5): Embedding(286, 50)\n", | |
| " (6): Embedding(3481, 50)\n", | |
| " (7): Embedding(304, 50)\n", | |
| " (8): Embedding(570, 50)\n", | |
| " (9): Embedding(30, 16)\n", | |
| " (10): Embedding(26, 14)\n", | |
| " (11): Embedding(300, 50)\n", | |
| " (12): Embedding(33283, 50)\n", | |
| " (13): Embedding(5, 3)\n", | |
| " (14): Embedding(5, 3)\n", | |
| " (15): Embedding(3, 2)\n", | |
| " (16): Embedding(3, 2)\n", | |
| " )\n", | |
| " (emb_drop): Dropout(p=0.05)\n", | |
| " (bn_cont): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
| " (layers): Sequential(\n", | |
| " (0): Linear(in_features=462, out_features=64, bias=True)\n", | |
| " (1): ReLU(inplace)\n", | |
| " )\n", | |
| " )\n", | |
| " (mod_nlp): SequentialRNN(\n", | |
| " (0): MultiBatchRNNCore(\n", | |
| " (encoder): Embedding(22847, 400, padding_idx=1)\n", | |
| " (encoder_dp): EmbeddingDropout(\n", | |
| " (emb): Embedding(22847, 400, padding_idx=1)\n", | |
| " )\n", | |
| " (rnns): ModuleList(\n", | |
| " (0): WeightDropout(\n", | |
| " (module): LSTM(400, 1150)\n", | |
| " )\n", | |
| " (1): WeightDropout(\n", | |
| " (module): LSTM(1150, 1150)\n", | |
| " )\n", | |
| " (2): WeightDropout(\n", | |
| " (module): LSTM(1150, 400)\n", | |
| " )\n", | |
| " )\n", | |
| " (input_dp): RNNDropout()\n", | |
| " (hidden_dps): ModuleList(\n", | |
| " (0): RNNDropout()\n", | |
| " (1): RNNDropout()\n", | |
| " (2): RNNDropout()\n", | |
| " )\n", | |
| " )\n", | |
| " (1): PoolingLinearClassifier(\n", | |
| " (layers): Sequential(\n", | |
| " (0): BatchNorm1d(1200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
| " (1): Dropout(p=0.2)\n", | |
| " (2): Linear(in_features=1200, out_features=50, bias=True)\n", | |
| " (3): ReLU(inplace)\n", | |
| " )\n", | |
| " )\n", | |
| " )\n", | |
| " (layers): Sequential(\n", | |
| " (0): BatchNorm1d(114, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
| " (1): Dropout(p=0.8)\n", | |
| " (2): Linear(in_features=114, out_features=2, bias=True)\n", | |
| " )\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Head input = 64 tabular features + 50 text features (the widths left after trimming both
# heads above). A single heavily-regularised block (dropout 0.8) maps them to 2 output classes.
lin_layers = [64+50, 2]
ps = [0.8]
model = ConcatModel(learn_tab.model, learn_nlp.model, lin_layers, ps)
model
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Learner" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
loss_func = nn.CrossEntropyLoss()
# Reuse the text learner's five layer groups for discriminative learning rates and gradual
# unfreezing. The tabular body and the new head are appended to the last group, so they train
# together with the top of the text classifier from the first (frozen) stage onwards.
layer_groups = [nn.Sequential(*flatten_model(learn_nlp.layer_groups[0])),
                nn.Sequential(*flatten_model(learn_nlp.layer_groups[1])),
                nn.Sequential(*flatten_model(learn_nlp.layer_groups[2])),
                nn.Sequential(*flatten_model(learn_nlp.layer_groups[3])),
                nn.Sequential(*(flatten_model(learn_nlp.layer_groups[4]) + 
                                flatten_model(model.mod_tab) +
                                flatten_model(model.layers)))] 
learn = Learner(data, model, loss_func=loss_func, metrics=accuracy, layer_groups=layer_groups)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Train!" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# Freeze every layer group except the last one, then run the learning-rate range test.
learn.freeze()
learn.lr_find()
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
# Plot loss against learning rate from lr_find() to pick the max LR for the next stage.
learn.recorder.plot()
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Total time: 00:37\n", | |
| "epoch train_loss valid_loss accuracy\n", | |
| "1 0.106572 0.248390 0.920200 (00:37)\n", | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# Stage 1: train only the unfrozen last layer group for one cycle.
learn.fit_one_cycle(1, 1e-2, moms=(0.8, 0.7))
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Total time: 00:40\n", | |
| "epoch train_loss valid_loss accuracy\n", | |
| "1 0.086336 0.256554 0.919800 (00:40)\n", | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# Stage 2: unfreeze the last two groups. slice(lr/2.6**4, lr) spreads the learning rates across
# the five groups, so each group gets 1/2.6 of the LR of the group above it (ULMFiT-style).
learn.freeze_to(-2)
learn.fit_one_cycle(1, slice(5e-3/(2.6**4), 5e-3), moms=(0.8, 0.7))
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Total time: 01:03\n", | |
| "epoch train_loss valid_loss accuracy\n", | |
| "1 0.097170 0.257217 0.919500 (01:03)\n", | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# Stage 3: unfreeze the last three groups, with a lower peak LR and the same 2.6x per-group decay.
learn.freeze_to(-3)
learn.fit_one_cycle(1, slice(2e-3/(2.6**4), 2e-3), moms=(0.8, 0.7))
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Total time: 07:04\n", | |
| "epoch train_loss valid_loss accuracy\n", | |
| "1 0.080045 0.260310 0.920200 (01:24)\n", | |
| "2 0.075644 0.249944 0.922800 (01:26)\n", | |
| "3 0.071381 0.271557 0.920900 (01:26)\n", | |
| "4 0.078788 0.290130 0.919600 (01:24)\n", | |
| "5 0.088786 0.268973 0.921800 (01:23)\n", | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# Stage 4: fine-tune the whole network for five epochs with discriminative learning rates.
learn.unfreeze()
learn.fit_one_cycle(5, slice(5e-4/(2.6**4), 5e-4), moms=(0.8, 0.7))
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Total time: 06:51\n", | |
| "epoch train_loss valid_loss accuracy\n", | |
| "1 0.077082 0.248748 0.924100 (01:21)\n", | |
| "2 0.081846 0.249953 0.923700 (01:20)\n", | |
| "3 0.088959 0.254498 0.920200 (01:23)\n", | |
| "4 0.056842 0.249644 0.922800 (01:21)\n", | |
| "5 0.067153 0.244735 0.922900 (01:24)\n", | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# Same schedule as the previous cell, continuing from its weights, with stronger weight decay (wd=0.1).
learn.fit_one_cycle(5, slice(5e-4/(2.6**4), 5e-4), moms=(0.8, 0.7), wd=1e-1)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python [conda env:fastai]", | |
| "language": "python", | |
| "name": "conda-env-fastai-py" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.6" | |
| }, | |
| "varInspector": { | |
| "cols": { | |
| "lenName": 16, | |
| "lenType": 16, | |
| "lenVar": 40 | |
| }, | |
| "kernels_config": { | |
| "python": { | |
| "delete_cmd_postfix": "", | |
| "delete_cmd_prefix": "del ", | |
| "library": "var_list.py", | |
| "varRefreshCmd": "print(var_dic_list())" | |
| }, | |
| "r": { | |
| "delete_cmd_postfix": ") ", | |
| "delete_cmd_prefix": "rm(", | |
| "library": "var_list.r", | |
| "varRefreshCmd": "cat(var_dic_list()) " | |
| } | |
| }, | |
| "types_to_exclude": [ | |
| "module", | |
| "function", | |
| "builtin_function_or_method", | |
| "instance", | |
| "_Feature" | |
| ], | |
| "window_display": false | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Author
Hello! You've done a nice job, and I have a question: once this model is trained, how can you predict a single example? It doesn't work with .predict(example).
I've done this with two AWD_LSTM networks, but in the end I ran into this error when making a prediction:
AttributeError: 'ConcatDataset' object has no attribute 'set_item'
Best regards
Hello! You've done a nice job, and I have a question: once this model is trained, how can you predict a single example? It doesn't work with .predict(example).
I've done this with two AWD_LSTM networks, but in the end I ran into this error when making a prediction:
AttributeError: 'ConcatDataset' object has no attribute 'set_item'
Best regards
Same problem here
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
make gist public