@alper111
Last active April 8, 2025 08:10
PyTorch implementation of VGG perceptual loss
import torch
import torchvision


class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # Slices of VGG16 features ending at relu1_2, relu2_2, relu3_3, relu4_3
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        # The loss network is fixed: freeze all VGG parameters
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        # ImageNet statistics used to normalize inputs before the VGG forward pass
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        # Replicate single-channel inputs to three channels
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                # Feature (content) loss: L1 distance between activations
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Style loss: L1 distance between Gram matrices of the activations
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
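
A minimal usage sketch for the class above (the tensor shapes, the [0, 1] value range, and the variable names below are illustrative assumptions, not part of the gist):

loss_fn = VGGPerceptualLoss(resize=True)                  # move to GPU with .to(device) if desired
pred = torch.rand(4, 3, 256, 256, requires_grad=True)     # hypothetical prediction in [0, 1]
gt = torch.rand(4, 3, 256, 256)                           # hypothetical target in [0, 1]
loss = loss_fn(pred, gt, feature_layers=[0, 1, 2, 3], style_layers=[])
loss.backward()                                           # gradients flow to pred; VGG stays frozen
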
@cdalinghaus

My grayscale image data had no explicit color channel, so I've added a small check for that:

# Input is greyscale and of shape (batch, x, y) instead of (batch, 1, x, y)
# Add a color dimension
if len(input.shape) == 3:
    input = input.unsqueeze(1)
    target = target.unsqueeze(1)

Also, to remove the deprecation warning:

blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[:4].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[4:9].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[9:16].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[16:23].eval())

@chiehwangs

Very useful tool! I am very confused. When I use

blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())

the GPU memory usage rises sharply in the middle of training, suddenly increasing by about 7 GB!

But when I use

blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[:4].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[4:9].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[9:16].eval())
blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT).features[16:23].eval())

the problem does not occur?

That's weird, can someone tell me why?

@ndming

ndming commented Nov 17, 2024

Hi, thanks for sharing your implementation. I would like to know if there's a specific requirement on the input image range, since I plan to use perceptual loss on HDR images whose values exceed [0, 1]. Should I rescale or clamp them to [0, 1] before feeding them to the loss? I notice your loss handles the normalization, but it seems to expect inputs in the range [0, 1].

@MohitLamba94

In my understanding, as long as the images are roughly in the range [-1, 1] (slightly exceeding it is no problem) at the time they are fed to VGG, things are fine. But if the images fed to VGG are purely positive, i.e. in [0, 1+delta], I don't know, because during training the images fed to VGG had both positive and negative values.
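
For reference, a quick check (a small illustrative script, not part of the gist) of what the loss's own ImageNet normalization does to inputs in [0, 1]:

import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

low = (torch.zeros(1, 3, 1, 1) - mean) / std    # roughly [-2.12, -2.04, -1.80] per channel
high = (torch.ones(1, 3, 1, 1) - mean) / std    # roughly [ 2.25,  2.43,  2.64] per channel
print(low.flatten(), high.flatten())

So inputs in [0, 1] do end up with both negative and positive values once the mean/std normalization inside the loss is applied.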

@MohitLamba94

Also, while feeding the target, i.e. y, can't we wrap it inside torch.no_grad() to save computation? Under no circumstances will gradients need to backpropagate to the target; only the prediction needs backpropagation, so it should not be wrapped.
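
A sketch of how the inner loop of forward could be adjusted along these lines (same names as in the gist; an untested suggestion, not the author's code):

for i, block in enumerate(self.blocks):
    x = block(x)
    with torch.no_grad():   # target features never need gradients
        y = block(y)
    if i in feature_layers:
        loss += torch.nn.functional.l1_loss(x, y)
    # ... style-layer branch unchanged ...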

@ndming

ndming commented Nov 17, 2024

In my understanding, as long as the images are roughly in the range [-1, 1] (slightly exceeding it is no problem) at the time they are fed to VGG, things are fine. But if the images fed to VGG are purely positive, i.e. in [0, 1+delta], I don't know, because during training the images fed to VGG had both positive and negative values.

I just noticed this in the Keras docs:

For VGG16, call keras.applications.vgg16.preprocess_input on your inputs before passing them to the model. vgg16.preprocess_input will convert the input images from RGB to BGR, then will zero-center each color channel with respect to the ImageNet dataset, without scaling.

So I guess we don't need to specifically rescale the inputs, as long as they are normalized to ImageNet's mean and std. However, should we also handle the conversion from RGB to BGR, or just assume the input image channels already have that order?
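
If a channel swap ever were needed, it is just an index permutation on the channel dimension before calling the loss (illustrative snippet with a hypothetical rgb_batch tensor; whether BGR is actually required depends on how the pretrained weights were produced):

import torch

rgb_batch = torch.rand(4, 3, 224, 224)       # hypothetical RGB batch in NCHW layout
bgr_batch = rgb_batch[:, [2, 1, 0], :, :]    # reverse the channel order -> BGR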

@ndming

ndming commented Nov 17, 2024

Also, while feeding the target, i.e. y, can't we wrap it inside torch.no_grad() to save computation? Under no circumstances will gradients need to backpropagate to the target; only the prediction needs backpropagation, so it should not be wrapped.

Yes, I believe we don't need to track the gradient for the target when feeding it through VGG.
