Generic PyTorch Module Wrapper - When nn.Sequential just isn't enough
# I keep properties on my main nn.Modules, e.g. a list of the training statistics the model is tracking.
# I wanted to perform a set of extra actions across multiple different modules without having to
# - write those steps into each of the 5+ different model definitions, or
# - explicitly expose those values on the wrapper module.
# It's fairly trivial, but without the try: super().__getattr__(name) call, the lookup of `wrapped`
# itself fails (nn.Module stores submodules in _modules, not __dict__) and the fallback recurses.
import torch
import torch.nn as nn


class Wrapper(nn.Module):
    def __init__(self, wrapped):
        super().__init__()
        self.wrapped = wrapped

    def forward(self, x):
        out = self.wrapped(x)
        # insert fancy logic here
        return out

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails. Try nn.Module's own
        # __getattr__ first so registered parameters, buffers and submodules
        # (including `wrapped` itself) are found, then fall back to attributes
        # of the wrapped module.
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == "wrapped":
                # Avoid infinite recursion if `wrapped` has not been set yet.
                raise AttributeError(name)
            return getattr(self.wrapped, name)
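
# A minimal usage sketch, for illustration only: `TinyClassifier`, its `fc` layer,
# the `stats` dict, and the input shapes are hypothetical stand-ins, not part of the
# original gist. It just shows attributes of the wrapped module being reachable
# through the wrapper.
class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)
        self.stats = {"batches_seen": 0}  # hypothetical tracked statistic

    def forward(self, x):
        self.stats["batches_seen"] += 1
        return self.fc(x)


model = Wrapper(TinyClassifier())
out = model(torch.randn(4, 8))  # Wrapper.forward -> TinyClassifier.forward
print(model.stats)              # forwarded to the wrapped module: {'batches_seen': 1}
print(model.wrapped)            # the wrapped submodule itself is still reachable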