Gist for projected gradient descent adversarial attack using PyTorch
import torch


def projected_gradient_descent(model, x, y, loss_fn, num_steps, step_size, step_norm, eps, eps_norm,
                               clamp=(0, 1), y_target=None):
    """Performs the projected gradient descent attack on a batch of images.

    Args:
        model: callable mapping a batch of images to logits
        x: batch of input images as a 4D tensor of shape (N, C, H, W)
        y: ground-truth labels for x
        loss_fn: loss to ascend (or descend, for a targeted attack), e.g. cross-entropy
        num_steps: number of gradient steps to take
        step_size: size of each gradient step
        step_norm: norm used to normalise each step ('inf' or a p-value such as 2)
        eps: radius of the ball the perturbation is projected back into
        eps_norm: norm defining that ball ('inf' or a p-value such as 2)
        clamp: valid range of image values, default (0, 1)
        y_target: if given, run a targeted attack towards this label instead
    """
    x_adv = x.clone().detach().requires_grad_(True).to(x.device)
    targeted = y_target is not None

    for i in range(num_steps):
        _x_adv = x_adv.clone().detach().requires_grad_(True)

        prediction = model(_x_adv)
        loss = loss_fn(prediction, y_target if targeted else y)
        loss.backward()

        with torch.no_grad():
            # Force the gradient step to be a fixed size in a certain norm
            if step_norm == 'inf':
                gradients = _x_adv.grad.sign() * step_size
            else:
                # Note .view() assumes batched image data as a 4D tensor;
                # normalise the gradient per image so each step has size step_size
                gradients = _x_adv.grad * step_size / _x_adv.grad.view(_x_adv.shape[0], -1)\
                    .norm(step_norm, dim=-1)\
                    .view(-1, 1, 1, 1)

            if targeted:
                # Targeted: gradient descent on the loss of the (incorrect)
                # target label w.r.t. the image data
                x_adv -= gradients
            else:
                # Untargeted: gradient ascent on the loss of the correct label
                # w.r.t. the image data
                x_adv += gradients

        # Project back into the l_norm ball and the valid image range
        if eps_norm == 'inf':
            # Workaround as PyTorch doesn't have an elementwise clip
            x_adv = torch.max(torch.min(x_adv, x + eps), x - eps)
        else:
            delta = x_adv - x

            # Assume x and x_adv are batched tensors where the first dimension is
            # a batch dimension
            mask = delta.view(delta.shape[0], -1).norm(eps_norm, dim=1) <= eps

            scaling_factor = delta.view(delta.shape[0], -1).norm(eps_norm, dim=1)
            scaling_factor[mask] = eps

            # .view() assumes batched images as a 4D tensor
            delta *= eps / scaling_factor.view(-1, 1, 1, 1)

            x_adv = x + delta

        x_adv = x_adv.clamp(*clamp)

    return x_adv.detach()
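A minimal usage sketch, assuming a toy linear classifier and random images in [0, 1]; the model, shapes, and hyperparameter values below are illustrative and not part of the gist.

import torch
import torch.nn as nn

# Hypothetical classifier: any model mapping (N, C, H, W) images to logits works
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
model.eval()

x = torch.rand(8, 3, 32, 32)      # batch of images in [0, 1]
y = torch.randint(0, 10, (8,))    # ground-truth labels

x_adv = projected_gradient_descent(
    model, x, y,
    loss_fn=nn.CrossEntropyLoss(),
    num_steps=40,
    step_size=0.01,
    step_norm='inf',              # take l_inf-normalised (sign) gradient steps
    eps=8 / 255,
    eps_norm='inf',               # project back into an l_inf ball of radius eps
)

# The perturbation stays inside the eps-ball and the valid pixel range
assert (x_adv - x).abs().max() <= 8 / 255 + 1e-6
assert 0 <= x_adv.min() and x_adv.max() <= 1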
x_adv = x_adv.clamp(*clamp)

Can someone explain the need for this clamping?

It's equivalent to x_adv = x_adv.clamp(0, 1): the * unpacks the clamp tuple, whose default is (0, 1), so after each step the adversarial images are clipped back into the valid pixel range.
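A quick sketch of what that unpacking does, with made-up values:

import torch

x_adv = torch.tensor([-0.3, 0.5, 1.7])
clamp = (0, 1)

# x_adv.clamp(*clamp) is the same call as x_adv.clamp(0, 1): after the gradient
# step, every perturbed pixel is pushed back into the valid image range.
print(x_adv.clamp(*clamp))    # tensor([0.0000, 0.5000, 1.0000])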
Does the norm here mean eps_norm? (L45)

mask = delta.view(delta.shape[0], -1).norm(norm, dim=1) <= eps
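The projection only makes sense if that variable is eps_norm. A small sketch of the l2 case under that assumption, with made-up tensors and an illustrative eps:

import torch

eps, eps_norm = 1.0, 2             # illustrative l2 ball of radius 1
delta = torch.randn(4, 3, 8, 8)    # hypothetical batch of perturbations

# Per-image norms of the flattened perturbations
norms = delta.view(delta.shape[0], -1).norm(eps_norm, dim=1)

# Perturbations already inside the eps-ball are left unchanged (scale factor
# eps / eps = 1); the rest are rescaled onto the surface of the ball.
scaling_factor = norms.clone()
scaling_factor[norms <= eps] = eps
delta = delta * (eps / scaling_factor.view(-1, 1, 1, 1))

assert (delta.view(delta.shape[0], -1).norm(eps_norm, dim=1) <= eps + 1e-6).all()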