Created
September 14, 2017 02:23
-
-
Save crcrpar/a5d46738ffff08fc12138a5f270db426 to your computer and use it in GitHub Desktop.
[PyTorch] Pre-trained VGG16 for perceptual loss (e.g. style transfer)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Modified VGG16 for computing perceptual loss.

This module is mostly copied from pytorch/examples;
see fast_neural_style in https://github.com/pytorch/examples.
"""
import torch | |
from torchvision import models | |
class VGG_OUTPUT(object):
    """Container for the four intermediate activations returned by ``VGG16``.

    Attributes are named after the layer each activation was taken from:
    ``relu1_2``, ``relu2_2``, ``relu3_3`` and ``relu4_3``. Instances can also
    be iterated (or unpacked) in that same shallow-to-deep order.
    """

    def __init__(self, relu1_2, relu2_2, relu3_3, relu4_3):
        # Assign explicitly rather than via ``self.__dict__ = locals()``:
        # locals() also contains ``self``, which made every instance reference
        # itself. That reference cycle kept the stored feature maps alive until
        # the cyclic garbage collector ran, instead of freeing them as soon as
        # the last external reference went away.
        self.relu1_2 = relu1_2
        self.relu2_2 = relu2_2
        self.relu3_3 = relu3_3
        self.relu4_3 = relu4_3

    def __iter__(self):
        """Yield the activations in order relu1_2, relu2_2, relu3_3, relu4_3."""
        return iter((self.relu1_2, self.relu2_2, self.relu3_3, self.relu4_3))
class VGG16(torch.nn.Module):
    """VGG16 feature extractor for perceptual losses.

    Takes the convolutional ``features`` stack of torchvision's pretrained
    VGG16 and splits it into four consecutive slices. Their outputs are
    returned as ``relu1_2``, ``relu2_2``, ``relu3_3`` and ``relu4_3``.

    NOTE(review): the input is presumably an ImageNet-normalized RGB batch of
    shape (N, 3, H, W). This module does no normalization itself, so confirm
    at the call site.

    Args:
        requires_grad: if False (the default), every parameter is frozen so
            the network acts as a fixed feature extractor.
    """

    # Half-open [start, end) index ranges into ``vgg16().features`` for
    # slice1..slice4. Submodules keep their original indices as names, so
    # state_dict keys look like ``slice2.5.weight``.
    _SLICE_BOUNDS = ((0, 4), (4, 9), (9, 16), (16, 23))

    def __init__(self, requires_grad=False):
        super(VGG16, self).__init__()
        vgg_pretrained_features = self._pretrained_features()
        for number, (start, end) in enumerate(self._SLICE_BOUNDS, start=1):
            block = torch.nn.Sequential()
            for x in range(start, end):
                block.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, "slice%d" % number, block)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    @staticmethod
    def _pretrained_features():
        """Return ``features`` of torchvision's ImageNet-pretrained VGG16."""
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy flag.
            return models.vgg16(pretrained=True).features
        # torchvision >= 0.13 deprecates ``pretrained=True``. IMAGENET1K_V1 is
        # the checkpoint that flag selected, so the weights are unchanged.
        return models.vgg16(weights=weights_enum.IMAGENET1K_V1).features

    def forward(self, X):
        """Run ``X`` through the four slices and collect each slice's output.

        NOTE(review): torchvision's VGG ReLUs are in-place. The returned
        activations stay intact only because each later slice is assumed to
        start with a max-pool (indices 4, 9, 16 in torchvision's layout).
        Re-check this if ``_SLICE_BOUNDS`` changes.
        """
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        return VGG_OUTPUT(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
Original source: https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/vgg.py