Created October 29, 2017 13:35
Implementation of early stopping for Chainer
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import shutil\n",
    "import tempfile\n",
    "\n",
    "import chainer\n",
    "from chainer.training import extension\n",
    "\n",
    "\n",
"def _snapshot_object(trainer, target, filename, savefun):\n", | |
" fd, tmppath = tempfile.mkstemp()\n", | |
" try:\n", | |
" savefun(tmppath, target)\n", | |
" except Exception:\n", | |
" os.close(fd)\n", | |
" os.remove(tmppath)\n", | |
" raise\n", | |
" os.close(fd)\n", | |
" shutil.move(tmppath, filename)\n", | |
"\n", | |
"\n", | |
"class SaveRestore(chainer.training.extension.Extension):\n", | |
"\n", | |
" \"\"\"Trainer extension to save a snapshot and restore it at the end of\n", | |
" training.\n", | |
"\n", | |
" Typical usage is:\n", | |
"\n", | |
" .. code-block:: python\n", | |
"\n", | |
" trainer.extend(\n", | |
" SaveRestore(),\n", | |
" trigger=chainer.training.triggers.MinValueTrigger('validation/main/loss'))\n", | |
"\n", | |
" which save will save snapshots and apply (pseudo-) early stopping by\n", | |
" loading the snapshot with the best validation loss.\n", | |
"\n", | |
" Args:\n", | |
" filename (str): Name of the file into which the object is serialized.\n", | |
" It can be a format string, where the trainer object is passed to\n", | |
" the :meth:`str.format` method. For example,\n", | |
" ``'snapshot_{.updater.iteration}'`` is converted to\n", | |
" ``'snapshot_10000'`` at the 10,000th iteration.\n", | |
" Or you can give name without formatter, which will overwrite the\n", | |
" saved object on each call, thus only keeping the best model on\n", | |
" the disk.\n", | |
" Or you can give None, in which case the object is saved to\n", | |
" a temporaly path and deleted at the end of the training.\n", | |
" savefun: Function to save the object. It takes two arguments: the\n", | |
" output file path and the object to serialize.\n", | |
" loadfun: Function to load the object. It takes two arguments: the\n", | |
" file path and the object to deserialize.\n", | |
" \"\"\"\n", | |
" priority = -100\n", | |
"\n", | |
" def __init__(self, filename='snapshot_iter_{.updater.iteration}',\n", | |
" savefun=chainer.serializers.npz.save_npz,\n", | |
" loadfun=chainer.serializers.npz.load_npz):\n", | |
" super(SaveRestore, self).__init__()\n", | |
" self._savefun = savefun\n", | |
" self._loadfun = loadfun\n", | |
" self._saved_iteration = None\n", | |
" self._keep_snapshot = filename is not None\n", | |
" self._filename = filename or 'saverestore' + str(hash(random.random()))\n", | |
"\n", | |
" def __call__(self, trainer):\n", | |
" fn = self._filename.format(trainer)\n", | |
" self._saved_path = os.path.join(trainer.out, fn)\n", | |
" if not os.path.exists(trainer.out):\n", | |
" os.makedirs(trainer.out) \n", | |
" _snapshot_object(trainer, trainer, self._saved_path, self._savefun)\n", | |
" self._saved_iteration = trainer.updater.iteration\n", | |
" self._trainer = trainer # get referencee to trainer\n", | |
"\n", | |
" def finalize(self):\n", | |
" if self._saved_iteration is not None:\n", | |
" print('Loading model from %d iteration' % self._saved_iteration)\n", | |
" self._loadfun(self._saved_path, self._trainer)\n", | |
" else:\n", | |
" print('Warning: SaveRestore was never triggered')\n", | |
" if not self._keep_snapshot and os.path.exists(self._saved_path):\n", | |
" os.remove(self._saved_path)\n", | |
"\n" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
"iteration main/loss validation/main/loss main/accuracy validation/main/accuracy elapsed_time\n", | |
"\u001b[J1875 0.480014 0.307354 0.8797 0.9168 7.68877 \n", | |
"\u001b[J3750 0.303195 0.284905 0.91615 0.9215 15.8067 \n", | |
"\u001b[J5625 0.282889 0.278314 0.921 0.9229 23.1525 \n", | |
"\u001b[J7500 0.27267 0.268663 0.924017 0.9259 30.866 \n", | |
"\u001b[J9375 0.26581 0.267392 0.925883 0.9257 38.6127 \n", | |
"\u001b[J11250 0.261389 0.263896 0.92705 0.9255 45.3136 \n", | |
"\u001b[J13125 0.258034 0.265104 0.92865 0.9262 52.9158 \n", | |
"\u001b[J15000 0.255451 0.268626 0.9289 0.9256 59.4307 \n", | |
"\u001b[J16875 0.252378 0.267866 0.930583 0.9261 65.8631 \n", | |
"\u001b[J18750 0.250997 0.265482 0.930683 0.9275 72.0764 \n", | |
"\u001b[J20625 0.249587 0.265421 0.931283 0.9277 78.8752 \n", | |
"\u001b[J22500 0.247599 0.270085 0.931817 0.9254 87.0834 \n", | |
"\u001b[J24375 0.246133 0.263291 0.932367 0.9275 99.086 \n", | |
"\u001b[J26250 0.244599 0.264455 0.932967 0.9282 111.046 \n", | |
"\u001b[J28125 0.243589 0.267765 0.932033 0.9283 124.376 \n", | |
"\u001b[J30000 0.242842 0.268463 0.933083 0.9255 133.075 \n", | |
"\u001b[J31875 0.24204 0.26628 0.9336 0.9274 140.766 \n", | |
"\u001b[J33750 0.241006 0.266837 0.933417 0.9273 148.613 \n", | |
"\u001b[J35625 0.240575 0.271143 0.933567 0.9258 155.746 \n", | |
"\u001b[J37500 0.239719 0.266733 0.933717 0.9279 168.479 \n", | |
"Loading model from 24375 iteration\n" | |
] | |
} | |
], | |
"source": [ | |
"import chainer\n", | |
"import chainer.functions as F\n", | |
"import chainer.links as L\n", | |
"from chainer import training\n", | |
"from chainer.training import extensions\n", | |
"\n", | |
"# Network definition\n", | |
"class MLP(chainer.Chain):\n", | |
" def __init__(self, n_out):\n", | |
" super(MLP, self).__init__()\n", | |
" with self.init_scope():\n", | |
" self.l1 = L.Linear(None, n_out)\n", | |
"\n", | |
" def __call__(self, x):\n", | |
" return self.l1(x)\n", | |
"\n", | |
"model = L.Classifier(MLP(50))\n", | |
"\n", | |
"# Setup an optimizer\n", | |
"optimizer = chainer.optimizers.Adam()\n", | |
"optimizer.setup(model)\n", | |
"\n", | |
"# Load the MNIST dataset\n", | |
"train, dev = chainer.datasets.get_mnist()\n", | |
"\n", | |
"train_iter = chainer.iterators.SerialIterator(train, 32)\n", | |
"\n", | |
"# Set up a trainer\n", | |
"updater = training.StandardUpdater(train_iter, optimizer)\n", | |
"trainer = training.Trainer(updater, (20, 'epoch'), out='result')\n", | |
"dev_iter = chainer.iterators.SerialIterator(dev, 200, repeat=False, shuffle=False)\n", | |
"\n", | |
"# Evaluate the model with the test dataset for each epoch\n", | |
"trainer.extend(extensions.Evaluator(dev_iter, model))\n", | |
"\n", | |
"# Write a log of evaluation statistics for each epoch\n", | |
"trainer.extend(extensions.LogReport())\n", | |
"trainer.extend(SaveRestore(),\n", | |
" trigger=chainer.training.triggers.MinValueTrigger('validation/main/loss'))\n", | |
"trainer.extend(extensions.PrintReport(\n", | |
" ['iteration', 'main/loss', 'validation/main/loss',\n", | |
" 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))\n", | |
"\n", | |
"# Run the training\n", | |
"trainer.run()" | |
] | |
}, | |
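  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell is an illustrative addition, not part of the original gist. Because `SaveRestore` reloads the best snapshot in `finalize()`, the `model` object left in memory after `trainer.run()` holds the parameters from the iteration with the lowest validation loss. Saving just that model (instead of the full trainer snapshot) gives a file that can be loaded for inference without rebuilding the trainer; the path `result/best_model.npz` is an arbitrary example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Illustrative sketch (not in the original gist): persist only the best model.\n",
    "# After trainer.run(), SaveRestore.finalize() has already reloaded the snapshot\n",
    "# with the lowest validation loss, so `model` holds the best parameters.\n",
    "chainer.serializers.save_npz('result/best_model.npz', model)\n",
    "\n",
    "# Load it back into a freshly built network of the same shape for inference.\n",
    "restored = L.Classifier(MLP(50))\n",
    "chainer.serializers.load_npz('result/best_model.npz', restored)"
   ]
  },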
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": []
  },
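  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A further illustrative sketch, also not part of the original gist: the approach above is pseudo early stopping in that training still runs for the full 20 epochs and only the best snapshot is restored afterwards. Chainer v4 and later ship `chainer.training.triggers.EarlyStoppingTrigger`, which can serve as the trainer's stop trigger so that training actually halts once the monitored value stops improving, while `SaveRestore` still restores the best snapshot at the end. Argument names (notably the patience setting) differ between Chainer versions, so check the documentation of your release before running this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Illustrative sketch assuming Chainer v4+ (EarlyStoppingTrigger is not in every release).\n",
    "# It reuses `train`, `dev` and `MLP` defined above and builds fresh iterators and a\n",
    "# fresh model so that training starts from scratch.\n",
    "from chainer.training.triggers import EarlyStoppingTrigger, MinValueTrigger\n",
    "\n",
    "model2 = L.Classifier(MLP(50))\n",
    "optimizer2 = chainer.optimizers.Adam()\n",
    "optimizer2.setup(model2)\n",
    "train_iter2 = chainer.iterators.SerialIterator(train, 32)\n",
    "dev_iter2 = chainer.iterators.SerialIterator(dev, 200, repeat=False, shuffle=False)\n",
    "updater2 = training.StandardUpdater(train_iter2, optimizer2)\n",
    "\n",
    "stop_trigger = EarlyStoppingTrigger(\n",
    "    monitor='validation/main/loss',  # value watched for improvement\n",
    "    check_trigger=(1, 'epoch'),      # how often to check it\n",
    "    max_trigger=(20, 'epoch'))       # hard upper bound on training length\n",
    "\n",
    "trainer2 = training.Trainer(updater2, stop_trigger, out='result')\n",
    "trainer2.extend(extensions.Evaluator(dev_iter2, model2))\n",
    "trainer2.extend(extensions.LogReport())\n",
    "trainer2.extend(SaveRestore(),\n",
    "                trigger=MinValueTrigger('validation/main/loss'))\n",
    "# trainer2.run()  # uncomment to actually train with true early stopping"
   ]
  }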
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}