Skip to content

Instantly share code, notes, and snippets.

@koreyou
Created October 29, 2017 13:35
Show Gist options
  • Save koreyou/a91189b53c575dd722a767ecfa9fc5e9 to your computer and use it in GitHub Desktop.
Save koreyou/a91189b53c575dd722a767ecfa9fc5e9 to your computer and use it in GitHub Desktop.
Implementation of early stopping for Chainer
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"import shutil\n",
"import tempfile\n",
"\n",
"import chainer\n",
"from chainer.training import extension\n",
"\n",
"\n",
"def _snapshot_object(trainer, target, filename, savefun):\n",
" fd, tmppath = tempfile.mkstemp()\n",
" try:\n",
" savefun(tmppath, target)\n",
" except Exception:\n",
" os.close(fd)\n",
" os.remove(tmppath)\n",
" raise\n",
" os.close(fd)\n",
" shutil.move(tmppath, filename)\n",
"\n",
"\n",
class SaveRestore(chainer.training.extension.Extension):

    """Trainer extension to save a snapshot and restore it at the end of
    training.

    Typical usage is:

    .. code-block:: python

       trainer.extend(
           SaveRestore(),
           trigger=chainer.training.triggers.MinValueTrigger('validation/main/loss'))

    which will save snapshots and apply (pseudo-) early stopping by
    loading the snapshot with the best validation loss.

    Args:
        filename (str): Name of the file into which the object is serialized.
            It can be a format string, where the trainer object is passed to
            the :meth:`str.format` method. For example,
            ``'snapshot_{.updater.iteration}'`` is converted to
            ``'snapshot_10000'`` at the 10,000th iteration.
            Or you can give a name without a formatter, which will overwrite
            the saved object on each call, thus only keeping the best model
            on the disk.
            Or you can give None, in which case the object is saved to
            a temporary path and deleted at the end of the training.
        savefun: Function to save the object. It takes two arguments: the
            output file path and the object to serialize.
        loadfun: Function to load the object. It takes two arguments: the
            file path and the object to deserialize.
    """

    # Low priority so this runs after reporting extensions (e.g. Evaluator)
    # have produced the values the trigger inspects.
    priority = -100

    def __init__(self, filename='snapshot_iter_{.updater.iteration}',
                 savefun=chainer.serializers.npz.save_npz,
                 loadfun=chainer.serializers.npz.load_npz):
        super(SaveRestore, self).__init__()
        self._savefun = savefun
        self._loadfun = loadfun
        self._saved_iteration = None
        # BUG FIX: initialize these here so finalize() is safe even when
        # the extension was never triggered (previously finalize() raised
        # AttributeError on self._saved_path in that case).
        self._saved_path = None
        self._trainer = None
        self._keep_snapshot = filename is not None
        # When no filename is given, pick a (practically) unique temporary
        # name; the snapshot is deleted in finalize().
        self._filename = filename or 'saverestore' + str(hash(random.random()))

    def __call__(self, trainer):
        """Save a snapshot of the whole trainer into ``trainer.out``."""
        fn = self._filename.format(trainer)
        self._saved_path = os.path.join(trainer.out, fn)
        if not os.path.exists(trainer.out):
            os.makedirs(trainer.out)
        _snapshot_object(trainer, trainer, self._saved_path, self._savefun)
        self._saved_iteration = trainer.updater.iteration
        self._trainer = trainer  # keep a reference to the trainer for finalize()

    def finalize(self):
        """Restore the best snapshot, then delete it unless it is kept."""
        if self._saved_iteration is not None:
            print('Loading model from %d iteration' % self._saved_iteration)
            self._loadfun(self._saved_path, self._trainer)
        else:
            print('Warning: SaveRestore was never triggered')
        if (not self._keep_snapshot and self._saved_path is not None
                and os.path.exists(self._saved_path)):
            os.remove(self._saved_path)
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration main/loss validation/main/loss main/accuracy validation/main/accuracy elapsed_time\n",
"\u001b[J1875 0.480014 0.307354 0.8797 0.9168 7.68877 \n",
"\u001b[J3750 0.303195 0.284905 0.91615 0.9215 15.8067 \n",
"\u001b[J5625 0.282889 0.278314 0.921 0.9229 23.1525 \n",
"\u001b[J7500 0.27267 0.268663 0.924017 0.9259 30.866 \n",
"\u001b[J9375 0.26581 0.267392 0.925883 0.9257 38.6127 \n",
"\u001b[J11250 0.261389 0.263896 0.92705 0.9255 45.3136 \n",
"\u001b[J13125 0.258034 0.265104 0.92865 0.9262 52.9158 \n",
"\u001b[J15000 0.255451 0.268626 0.9289 0.9256 59.4307 \n",
"\u001b[J16875 0.252378 0.267866 0.930583 0.9261 65.8631 \n",
"\u001b[J18750 0.250997 0.265482 0.930683 0.9275 72.0764 \n",
"\u001b[J20625 0.249587 0.265421 0.931283 0.9277 78.8752 \n",
"\u001b[J22500 0.247599 0.270085 0.931817 0.9254 87.0834 \n",
"\u001b[J24375 0.246133 0.263291 0.932367 0.9275 99.086 \n",
"\u001b[J26250 0.244599 0.264455 0.932967 0.9282 111.046 \n",
"\u001b[J28125 0.243589 0.267765 0.932033 0.9283 124.376 \n",
"\u001b[J30000 0.242842 0.268463 0.933083 0.9255 133.075 \n",
"\u001b[J31875 0.24204 0.26628 0.9336 0.9274 140.766 \n",
"\u001b[J33750 0.241006 0.266837 0.933417 0.9273 148.613 \n",
"\u001b[J35625 0.240575 0.271143 0.933567 0.9258 155.746 \n",
"\u001b[J37500 0.239719 0.266733 0.933717 0.9279 168.479 \n",
"Loading model from 24375 iteration\n"
]
}
],
"source": [
"import chainer\n",
"import chainer.functions as F\n",
"import chainer.links as L\n",
"from chainer import training\n",
"from chainer.training import extensions\n",
"\n",
"# Network definition\n",
"class MLP(chainer.Chain):\n",
" def __init__(self, n_out):\n",
" super(MLP, self).__init__()\n",
" with self.init_scope():\n",
" self.l1 = L.Linear(None, n_out)\n",
"\n",
" def __call__(self, x):\n",
" return self.l1(x)\n",
"\n",
"model = L.Classifier(MLP(50))\n",
"\n",
"# Setup an optimizer\n",
"optimizer = chainer.optimizers.Adam()\n",
"optimizer.setup(model)\n",
"\n",
"# Load the MNIST dataset\n",
"train, dev = chainer.datasets.get_mnist()\n",
"\n",
"train_iter = chainer.iterators.SerialIterator(train, 32)\n",
"\n",
"# Set up a trainer\n",
"updater = training.StandardUpdater(train_iter, optimizer)\n",
"trainer = training.Trainer(updater, (20, 'epoch'), out='result')\n",
"dev_iter = chainer.iterators.SerialIterator(dev, 200, repeat=False, shuffle=False)\n",
"\n",
"# Evaluate the model with the test dataset for each epoch\n",
"trainer.extend(extensions.Evaluator(dev_iter, model))\n",
"\n",
"# Write a log of evaluation statistics for each epoch\n",
"trainer.extend(extensions.LogReport())\n",
"trainer.extend(SaveRestore(),\n",
" trigger=chainer.training.triggers.MinValueTrigger('validation/main/loss'))\n",
"trainer.extend(extensions.PrintReport(\n",
" ['iteration', 'main/loss', 'validation/main/loss',\n",
" 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))\n",
"\n",
"# Run the training\n",
"trainer.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment