Created April 25, 2020 09:06
Straw man proposal for PyTorch issue
""" | |
This is a straw man proposal to begin discussion of how to change the | |
PyTorch hooks API to support capture/inspection/modification of | |
keyword arguments. | |
https://github.com/pytorch/pytorch/issues/35643 | |
""" | |
import unittest
from collections import OrderedDict

import torch
import torch.utils.hooks  # provides RemovableHandle


class Module:
    """Minimal stand-in for torch.nn.Module, carrying only the hook
    plumbing this proposal needs.
    """

    def __init__(self):
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()

    def forward(self, *args, **kwargs):
        raise NotImplementedError()
    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            # Dispatch on the hook's arity by catching TypeError. The
            # try/except block is an inelegant hack; a real implementation
            # would inspect the hook's signature instead.
            try:
                result = hook(self, input, kwargs)
            except TypeError as e:
                if 'takes 2 positional' in str(e):
                    result = hook(self, input)
                else:
                    raise
            if result is not None:
                # The client possibly modified the input.
                if isinstance(result, tuple):
                    if len(result) == 2 and isinstance(result[0], tuple) \
                            and isinstance(result[1], dict):
                        # The client possibly modified both positional and
                        # keyword arguments.
                        input = result[0]
                        kwargs.update(result[1])
                    else:
                        # The client possibly modified the positional args.
                        input = result
                else:
                    # The client possibly modified the positional args and
                    # returned a non-tuple.
                    input = (result,)
        result = self.forward(*input, **kwargs)
        # TODO
        # Forward hooks
        # Backward hooks
        return result
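
    # Under this proposal, a pre-hook may return any of:
    #   None                       -> leave the arguments unchanged
    #   new_args (tuple or value)  -> replace the positional arguments
    #   (new_args, new_kwargs)     -> replace both; new_kwargs is merged in
    # A hypothetical illustration (the names below are placeholders, not
    # part of the proposal):
    #
    #     def doubling_hook(module, input, kwargs):
    #         return (tuple(2 * i for i in input),
    #                 {k: 2 * v for k, v in kwargs.items()})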

    def register_forward_pre_hook(self, hook):
        handle = torch.utils.hooks.RemovableHandle(self._forward_pre_hooks)
        self._forward_pre_hooks[handle.id] = hook
        return handle

    def register_forward_hook(self, hook):
        handle = torch.utils.hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

    def register_backward_hook(self, hook):
        # Bug fix: the handle must wrap the backward-hook dict, not the
        # forward-hook dict, or handle.remove() would target the wrong one.
        handle = torch.utils.hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle


class BinaryOrTernarySum(Module):
    """Toy module that sums two positional tensors and an optional
    keyword tensor.
    """

    def forward(self, x, y, z=None):
        output = x + y
        if z is not None:
            output += z
        return output


def forward_pre_hook_with_kwargs(module, input, kwargs=None):
    """Increment the positional and keyword arguments by one."""
    assert isinstance(input, tuple)
    # Tensor.__iadd__ operates in place, so the tensors inside `input` and
    # `kwargs` are updated even though the loop variables are locals.
    for i in input:
        i += 1
    if kwargs is not None:
        assert isinstance(kwargs, dict)
        for v in kwargs.values():
            v += 1
    return input, kwargs


def forward_pre_hook_backward_compatibility(module, input):
    """Increment the positional arguments by one."""
    assert isinstance(input, tuple)
    for i in input:
        i += 1  # in-place Tensor.__iadd__, as above
    return input
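

# The hooks above mutate tensors in place. Under the dispatcher's
# return-value protocol, a hook can instead build fresh values; a sketch
# (hypothetical, not exercised by the tests below):
def forward_pre_hook_functional(module, input, kwargs=None):
    new_input = tuple(i + 1 for i in input)
    new_kwargs = {k: v + 1 for k, v in (kwargs or {}).items()}
    return new_input, new_kwargs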


class TestModuleHooks(unittest.TestCase):
    def setUp(self):
        self.n = 2
        self.module = BinaryOrTernarySum()
        self.x = torch.zeros(self.n)
        self.y = torch.ones(self.n)
        self.z = torch.ones(self.n) * 2

    def test_baseline(self):
        # Without a forward pre-hook: 0 + 1 + 2 = 3.
        expected = torch.ones(self.n) * 3
        actual = self.module(self.x, self.y, z=self.z)
        self.assertTrue(torch.all(actual == expected))

    def test_forward_pre_hook_backward_compatibility(self):
        handle = self.module.register_forward_pre_hook(
            forward_pre_hook_backward_compatibility
        )
        # Only the positional args are incremented: x=1 + y=2 + z=2 = 5.
        expected = torch.ones(self.n) * 5
        actual = self.module(self.x, self.y, z=self.z)
        self.assertTrue(torch.all(actual == expected))
        handle.remove()

    def test_forward_pre_hook_using_kwargs(self):
        handle = self.module.register_forward_pre_hook(
            forward_pre_hook_with_kwargs
        )
        # Positional and keyword args are incremented: x=1 + y=2 + z=3 = 6.
        expected = torch.ones(self.n) * 6
        actual = self.module(self.x, self.y, z=self.z)
        self.assertTrue(torch.all(actual == expected))
        handle.remove()
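

# Allow running the file directly, e.g. `python straw_man_hooks.py`
# (the filename is illustrative).
if __name__ == '__main__':
    unittest.main()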