import torch
import torchvision


class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # Slice the pretrained VGG16 feature extractor into four blocks
        # (up to relu1_2, relu2_2, relu3_3 and relu4_3) and freeze them.
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        # ImageNet normalization constants, registered as buffers so they
        # follow the module across devices (e.g. .to(device)).
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        # Repeat single-channel (grayscale) inputs to three channels for VGG.
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            # Content (feature) loss: L1 distance between activations.
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            # Style loss: L1 distance between Gram matrices of the activations.
            if i in style_layers:
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
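A minimal usage sketch (the tensors below are hypothetical stand-ins for a model's prediction and a ground-truth batch; both are assumed to be RGB images in [0, 1], NCHW):

import torch

# Hypothetical setup: random tensors stand in for a model's output and the target.
device = "cuda" if torch.cuda.is_available() else "cpu"
loss_fn = VGGPerceptualLoss(resize=True).to(device)

pred = torch.rand(4, 3, 128, 128, device=device, requires_grad=True)
target = torch.rand(4, 3, 128, 128, device=device)

loss = loss_fn(pred, target)   # content loss on all four blocks by default
loss.backward()                # gradients flow back into `pred`
print(loss.item(), pred.grad.shape)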
Also, while feeding the target (i.e. y), can't we wrap it inside torch.no_grad() to save computation? Under no circumstances will gradients need to backpropagate to the target; only the prediction needs backpropagation, so only the prediction should stay outside the no_grad block.
In my understanding, as long as the images are roughly in the range [-1, 1] at the time they are fed to VGG (slightly exceeding it is not a problem), things are fine. But if the images fed to VGG are purely positive, i.e. in [0, 1+delta], I don't know, because during VGG's training the images it was fed had both positive and negative values.
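One option (a sketch, assuming a generator with a final tanh so its outputs lie in [-1, 1]; this setup is not taken from the gist) is to map the images back to [0, 1] before handing them to the loss, since the ImageNet mean/std applied inside forward are defined for images in [0, 1]:

import torch

# Hypothetical tensors standing in for tanh-bounded outputs in [-1, 1].
pred_tanh = torch.rand(2, 3, 64, 64) * 2 - 1
target_tanh = torch.rand(2, 3, 64, 64) * 2 - 1

# Map [-1, 1] -> [0, 1] so the ImageNet normalization inside
# VGGPerceptualLoss.forward applies as intended.
pred_01 = (pred_tanh + 1.0) / 2.0
target_01 = (target_tanh + 1.0) / 2.0

loss = VGGPerceptualLoss()(pred_01, target_01)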
I just noticed this in the Keras docs:

For VGG16, call keras.applications.vgg16.preprocess_input on your inputs before passing them to the model. vgg16.preprocess_input will convert the input images from RGB to BGR, then will zero-center each color channel with respect to the ImageNet dataset, without scaling.
So I guess we don't need to specifically rescale the inputs, as long as they are normalized to ImageNet's mean and std. However, should we also handle the conversion from RGB to BGR, or just assume the input image channels already have that order?
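For what it's worth, that Keras note describes Keras' own Caffe-style preprocessing; torchvision's pretrained VGG16 expects RGB input in [0, 1] normalized with the ImageNet mean/std already used above, so no BGR conversion should be needed for this implementation. Purely as an illustration, a channel reordering in PyTorch would be a simple index on the channel dimension (tensor names here are hypothetical):

import torch

rgb_batch = torch.rand(2, 3, 64, 64)         # hypothetical RGB batch, NCHW
bgr_batch = rgb_batch[:, [2, 1, 0], :, :]    # reorder channels R,G,B -> B,G,R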
Regarding wrapping the target (y) in torch.no_grad() to save computation: yes, I believe we don't need to track the gradient for the target when feeding it through VGG.
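A sketch of that change, applied to the loop inside forward (only the target branch is wrapped; the prediction keeps its autograd graph). Note that if the target is an ordinary data tensor with requires_grad=False, autograd already skips building a graph for that branch because the VGG weights are frozen, so the explicit no_grad() mainly matters when the target itself carries gradients.

# Modified loop body inside VGGPerceptualLoss.forward (sketch):
for i, block in enumerate(self.blocks):
    x = block(x)                  # prediction: gradients tracked
    with torch.no_grad():
        y = block(y)              # target: no autograd graph is recorded
    if i in feature_layers:
        loss += torch.nn.functional.l1_loss(x, y)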