{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"demonstrating the hook registering.\n", | |
"\n", | |
"mostly copy-pasted from http://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html" | |
   ]
  },
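  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference (an added note summarizing the hook API this notebook exercises): a forward hook has the signature `hook(module, input, output)` and a backward hook has `hook(module, grad_input, grad_output)`; the input and grad arguments may be tuples, which is why the functions below defensively wrap non-tuples."
   ]
  },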
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear (10 -> 20)\n",
      "Linear (20 -> 30)\n",
      "Inside Linear forward\n",
      "-\n",
      "type(input_i): <class 'torch.autograd.variable.Variable'>\n",
      "input_i.size(): torch.Size([8, 10])\n",
      "-\n",
      "output_i size: torch.Size([8, 20])\n",
      "output_i norm: 7.363391333259189\n",
      "==\n",
      "\n",
      "Inside Linear forward\n",
      "-\n",
      "type(input_i): <class 'torch.autograd.variable.Variable'>\n",
      "input_i.size(): torch.Size([8, 20])\n",
      "-\n",
      "output_i size: torch.Size([8, 30])\n",
      "output_i norm: 5.15335982781888\n",
      "==\n",
      "\n",
      "Inside Linear backward\n",
      "Inside class:Linear\n",
      "-\n",
      "type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
      "grad_input_i.size(): torch.Size([30])\n",
      "type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
      "grad_input_i.size(): torch.Size([8, 20])\n",
      "type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
      "grad_input_i.size(): torch.Size([20, 30])\n",
      "-\n",
      "type(grad_output_i): <class 'torch.autograd.variable.Variable'>\n",
      "grad_output_i.size(): torch.Size([8, 30])\n",
      "==\n",
      "\n",
      "Inside Linear backward\n",
      "Inside class:Linear\n",
      "-\n",
      "type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
      "grad_input_i.size(): torch.Size([20])\n",
      "type(grad_input_i): <class 'NoneType'>\n",
      "type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
      "grad_input_i.size(): torch.Size([10, 20])\n",
      "-\n",
      "type(grad_output_i): <class 'torch.autograd.variable.Variable'>\n",
      "grad_output_i.size(): torch.Size([8, 20])\n",
      "==\n",
      "\n"
     ]
    }
   ],
   "source": [
"def printnorm(self, input, output):\n", | |
" # input is a tuple of packed inputs\n", | |
" # output is a Variable. output.data is the Tensor we are interested\n", | |
" print('Inside ' + self.__class__.__name__ + ' forward')\n", | |
" print('-')\n", | |
" \n", | |
" if not isinstance(input, tuple):\n", | |
" input = (input,)\n", | |
" for input_i in input:\n", | |
" print('type(input_i): ', type(input_i))\n", | |
" print('input_i.size(): ', input_i.size())\n", | |
" \n", | |
" print('-')\n", | |
"\n", | |
" if not isinstance(output, tuple):\n", | |
" output = (output,)\n", | |
" for output_i in output:\n", | |
" print('output_i size:', output_i.data.size())\n", | |
" print('output_i norm:', output_i.data.norm())\n", | |
" print('==\\n')\n", | |
"\n", | |
" \n", | |
"def printgradnorm(self, grad_input, grad_output):\n", | |
" print('Inside ' + self.__class__.__name__ + ' backward')\n", | |
" print('Inside class:' + self.__class__.__name__)\n", | |
" print('-')\n", | |
" \n", | |
" if not isinstance(grad_input, tuple):\n", | |
" grad_input = (grad_input,)\n", | |
" \n", | |
" for grad_input_i in grad_input:\n", | |
" print('type(grad_input_i): ', type(grad_input_i))\n", | |
" if grad_input_i is not None:\n", | |
" print('grad_input_i.size(): ', grad_input_i.size())\n", | |
" \n", | |
" print('-')\n", | |
" \n", | |
" if not isinstance(grad_output, tuple):\n", | |
" grad_output = (grad_output,)\n", | |
"\n", | |
" \n", | |
" for grad_output_i in grad_output:\n", | |
" print('type(grad_output_i): ', type(grad_output_i))\n", | |
" if grad_output_i is not None:\n", | |
" print('grad_output_i.size(): ', grad_output_i.size())\n", | |
" \n", | |
" print('==\\n')\n", | |
" \n", | |
"fc1 = nn.Linear(in_features=10, out_features=20)\n", | |
"fc2 = nn.Linear(in_features=20, out_features=30)\n", | |
" \n", | |
"fc1.register_backward_hook(printgradnorm)\n", | |
"fc1.register_forward_hook(printnorm)\n", | |
"fc2.register_backward_hook(printgradnorm)\n", | |
"fc2.register_forward_hook(printnorm)\n", | |
"\n", | |
"print(fc1)\n", | |
"print(fc2)\n", | |
"\n", | |
"x = Variable(torch.randn(8,10))\n", | |
"o1 = fc1(x)\n", | |
"o2 = fc2(o1)\n", | |
"err = o2.mean()\n", | |
"err.backward()" | |
] | |
}, | |
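  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Side notes (not in the original tutorial): judging by the sizes printed above, the three `grad_input` entries for a `Linear` layer appear to be the gradients w.r.t. the bias (`[30]`), the layer input (`[8, 20]`), and the transposed weight (`[20, 30]`).\n",
    "\n",
    "Also, each `register_*_hook` call returns a handle, and `handle.remove()` detaches the hook again. A minimal sketch (`fc3` is a fresh layer made up for illustration):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# register_*_hook returns a removable handle; handle.remove() detaches\n",
    "# the hook, so instrumentation can be switched off after debugging.\n",
    "fc3 = nn.Linear(in_features=10, out_features=5)\n",
    "handle = fc3.register_forward_hook(printnorm)\n",
    "\n",
    "x = Variable(torch.randn(8, 10))\n",
    "_ = fc3(x)        # printnorm fires here\n",
    "handle.remove()   # detach the hook\n",
    "_ = fc3(x)        # runs silently now"
   ]
  },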
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inside Sequential forward\n",
      "-\n",
      "type(input_i): <class 'torch.autograd.variable.Variable'>\n",
      "input_i.size(): torch.Size([8, 10])\n",
      "-\n",
      "output_i size: torch.Size([8, 30])\n",
      "output_i norm: 5.458936063359881\n",
      "==\n",
      "\n",
      "Inside Sequential backward\n",
      "Inside class:Sequential\n",
      "-\n",
      "type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
      "grad_input_i.size(): torch.Size([30])\n",
      "type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
      "grad_input_i.size(): torch.Size([8, 20])\n",
      "type(grad_input_i): <class 'torch.autograd.variable.Variable'>\n",
      "grad_input_i.size(): torch.Size([20, 30])\n",
      "-\n",
      "type(grad_output_i): <class 'torch.autograd.variable.Variable'>\n",
      "grad_output_i.size(): torch.Size([8, 30])\n",
      "==\n",
      "\n"
     ]
    }
   ],
   "source": [
"## the hook registry is not recursive.. e.g. in the case that you have a Module with submodules.., \n", | |
"from collections import OrderedDict\n", | |
"net = nn.Sequential(OrderedDict({\n", | |
" 'fc1': nn.Linear(in_features=10, out_features=20),\n", | |
" 'fc2': nn.Linear(in_features=20, out_features=30)}))\n", | |
" \n", | |
"net.register_forward_hook(printnorm)\n", | |
"net.register_backward_hook(printgradnorm)\n", | |
"\n", | |
"\n", | |
"x = Variable(torch.randn(8,10))\n", | |
"y = net(x)\n", | |
"err = y.mean()\n", | |
"err.backward()" | |
] | |
} | |
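  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note how the `Sequential` hooks fire only once, for the container as a whole, and the backward `grad_input` sizes match the last `Linear` rather than the full stack. If per-layer output is wanted, one workaround (a sketch, not part of the original tutorial) is to walk the submodules and register the hook on each:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hook each immediate submodule individually; keep the handles around\n",
    "# so the hooks can all be detached afterwards.\n",
    "net2 = nn.Sequential(OrderedDict({\n",
    "    'fc1': nn.Linear(in_features=10, out_features=20),\n",
    "    'fc2': nn.Linear(in_features=20, out_features=30)}))\n",
    "\n",
    "handles = [m.register_forward_hook(printnorm) for m in net2.children()]\n",
    "\n",
    "x = Variable(torch.randn(8, 10))\n",
    "y = net2(x)  # printnorm now fires once per Linear layer\n",
    "\n",
    "for h in handles:\n",
    "    h.remove()"
   ]
  }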
 ],
 "metadata": {
  "_draft": {
   "nbviewer_url": "https://gist.github.com/53ae672f0de85da6db3e69533a00c0d3"
  },
  "gist": {
   "data": {
    "description": "registering hooks in pytorch",
    "public": true
   },
   "id": "53ae672f0de85da6db3e69533a00c0d3"
  },
  "kernelspec": {
   "display_name": "t3",
   "language": "python",
   "name": "t3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}