Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save braingineer/9b8205b16c6c2fa50a2afd85a5448742 to your computer and use it in GitHub Desktop.
Save braingineer/9b8205b16c6c2fa50a2afd85a5448742 to your computer and use it in GitHub Desktop.
registering hooks in pytorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Demonstrating how to register forward and backward hooks on PyTorch modules.\n",
"\n",
"Mostly copied from the PyTorch tutorial at http://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from torch import nn\n",
"import torch\n",
"from torch.autograd import Variable"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Linear (10 -> 20)\n",
"Linear (20 -> 30)\n",
"Inside Linear forward\n",
"-\n",
"type(input_i): <class 'torch.autograd.variable.Variable'>\n",
"input_i.size(): torch.Size([8, 10])\n",
"-\n",
"output_i size: torch.Size([8, 20])\n",
"output_i norm: 7.363391333259189\n",
"==\n",
"\n",
"Inside Linear forward\n",
"-\n",
"type(input_i): <class 'torch.autograd.variable.Variable'>\n",
"input_i.size(): torch.Size([8, 20])\n",
"-\n",
"output_i size: torch.Size([8, 30])\n",
"output_i norm: 5.15335982781888\n",
"==\n",
"\n",
"Inside Linear backward\n",
"Inside class:Linear\n",
"-\n",
"type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
"grad_input_i.size(): torch.Size([30])\n",
"type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
"grad_input_i.size(): torch.Size([8, 20])\n",
"type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
"grad_input_i.size(): torch.Size([20, 30])\n",
"-\n",
"type(grad_output_i): <class 'torch.autograd.variable.Variable'>\n",
"grad_output_i.size(): torch.Size([8, 30])\n",
"==\n",
"\n",
"Inside Linear backward\n",
"Inside class:Linear\n",
"-\n",
"type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
"grad_input_i.size(): torch.Size([20])\n",
"type(grad_input_i): <class 'NoneType'>\n",
"type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
"grad_input_i.size(): torch.Size([10, 20])\n",
"-\n",
"type(grad_output_i): <class 'torch.autograd.variable.Variable'>\n",
"grad_output_i.size(): torch.Size([8, 20])\n",
"==\n",
"\n"
]
}
],
"source": [
"def printnorm(self, input, output):\n",
"    \"\"\"Forward hook that prints what a module received and produced.\n",
"\n",
"    Registered below with ``register_forward_hook``; it is called with the\n",
"    module (bound to ``self`` here), the module's input and its output.\n",
"\n",
"    Args:\n",
"        self: the module the hook is attached to; only its class name is used.\n",
"        input: the module's inputs, normally a tuple of packed inputs (a bare\n",
"            value is wrapped into a 1-tuple below).\n",
"        output: the module's output, a Variable or a tuple of them; for each\n",
"            one, ``.data`` is the tensor whose size and norm are printed.\n",
"\n",
"    ``None`` entries are reported instead of raising AttributeError, matching\n",
"    how ``printgradnorm`` treats missing gradients.\n",
"    \"\"\"\n",
"    print('Inside ' + self.__class__.__name__ + ' forward')\n",
"    print('-')\n",
"\n",
"    if not isinstance(input, tuple):\n",
"        input = (input,)\n",
"    for input_i in input:\n",
"        print('type(input_i): ', type(input_i))\n",
"        if input_i is not None:  # a None entry has no .size()\n",
"            print('input_i.size(): ', input_i.size())\n",
"\n",
"    print('-')\n",
"\n",
"    if not isinstance(output, tuple):\n",
"        output = (output,)\n",
"    for output_i in output:\n",
"        if output_i is None:\n",
"            print('output_i: None')\n",
"            continue\n",
"        print('output_i size:', output_i.data.size())\n",
"        print('output_i norm:', output_i.data.norm())\n",
"    print('==\\n')\n",
"\n",
" \n",
"def printgradnorm(self, grad_input, grad_output):\n",
"    \"\"\"Backward hook that prints the type and size of each gradient it sees.\n",
"\n",
"    Registered below with ``register_backward_hook``. Both gradient arguments\n",
"    are normalised to tuples; ``None`` entries (e.g. for fc1's input, which\n",
"    does not require grad) are reported by type only.\n",
"\n",
"    In this notebook's recorded output, ``grad_input`` for an ``nn.Linear``\n",
"    has three entries sized like the bias, the layer input and the transposed\n",
"    weight -- the hook appears to see the gradients of the module's last\n",
"    autograd op rather than only the gradient w.r.t. the module input.\n",
"    \"\"\"\n",
"    name = self.__class__.__name__\n",
"    print('Inside ' + name + ' backward')\n",
"    print('Inside class:' + name)\n",
"    print('-')\n",
"\n",
"    grads_in = grad_input if isinstance(grad_input, tuple) else (grad_input,)\n",
"    for grad in grads_in:\n",
"        print('type(grad_input_i): ', type(grad))\n",
"        if grad is not None:\n",
"            print('grad_input_i.size(): ', grad.size())\n",
"\n",
"    print('-')\n",
"\n",
"    grads_out = grad_output if isinstance(grad_output, tuple) else (grad_output,)\n",
"    for grad in grads_out:\n",
"        print('type(grad_output_i): ', type(grad))\n",
"        if grad is not None:\n",
"            print('grad_output_i.size(): ', grad.size())\n",
"\n",
"    print('==\\n')\n",
" \n",
"# Seed the RNG so the Linear initialisation and the random input (and hence\n",
"# the printed norms) are the same on every Restart & Run All.\n",
"torch.manual_seed(0)\n",
"\n",
"fc1 = nn.Linear(in_features=10, out_features=20)\n",
"fc2 = nn.Linear(in_features=20, out_features=30)\n",
"\n",
"# Keep the handles returned by register_*_hook so the hooks can be detached\n",
"# later with handle.remove().\n",
"# NOTE: newer PyTorch releases deprecate register_backward_hook in favour of\n",
"# register_full_backward_hook; the legacy call is kept to match the\n",
"# Variable-era API used throughout this notebook.\n",
"hook_handles = [\n",
"    fc1.register_backward_hook(printgradnorm),\n",
"    fc1.register_forward_hook(printnorm),\n",
"    fc2.register_backward_hook(printgradnorm),\n",
"    fc2.register_forward_hook(printnorm),\n",
"]\n",
"\n",
"print(fc1)\n",
"print(fc2)\n",
"\n",
"x = Variable(torch.randn(8, 10))\n",
"o1 = fc1(x)\n",
"o2 = fc2(o1)\n",
"err = o2.mean()\n",
"err.backward()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inside Sequential forward\n",
"-\n",
"type(input_i): <class 'torch.autograd.variable.Variable'>\n",
"input_i.size(): torch.Size([8, 10])\n",
"-\n",
"output_i size: torch.Size([8, 30])\n",
"output_i norm: 5.458936063359881\n",
"==\n",
"\n",
"Inside Sequential backward\n",
"Inside class:Sequential\n",
"-\n",
"type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
"grad_input_i.size(): torch.Size([30])\n",
"type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
"grad_input_i.size(): torch.Size([8, 20])\n",
"type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
"grad_input_i.size(): torch.Size([20, 30])\n",
"-\n",
"type(grad_output_i): <class 'torch.autograd.variable.Variable'>\n",
"grad_output_i.size(): torch.Size([8, 30])\n",
"==\n",
"\n"
]
}
],
"source": [
"# Hooks registered on a container fire for the container only; they are not\n",
"# propagated to its submodules (the recorded output shows no fc1/fc2 lines).\n",
"# The backward hook's grad_input is also not the gradient w.r.t. the\n",
"# Sequential's (8, 10) input: the recorded sizes match fc2's bias, fc2's\n",
"# input and fc2's transposed weight, apparently the last op in the container.\n",
"from collections import OrderedDict\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Build from (name, module) pairs: dict-literal key order is only a language\n",
"# guarantee from Python 3.7 and this kernel runs 3.6, so OrderedDict({...})\n",
"# does not reliably pin the layer order.\n",
"net = nn.Sequential(OrderedDict([\n",
"    ('fc1', nn.Linear(in_features=10, out_features=20)),\n",
"    ('fc2', nn.Linear(in_features=20, out_features=30)),\n",
"]))\n",
"\n",
"net.register_forward_hook(printnorm)\n",
"net.register_backward_hook(printgradnorm)\n",
"\n",
"x = Variable(torch.randn(8, 10))\n",
"y = net(x)\n",
"err = y.mean()\n",
"err.backward()"
]
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/53ae672f0de85da6db3e69533a00c0d3"
},
"gist": {
"data": {
"description": "registering hooks in pytorch",
"public": true
},
"id": "53ae672f0de85da6db3e69533a00c0d3"
},
"kernelspec": {
"display_name": "t3",
"language": "python",
"name": "t3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment