Normalization vs gradients
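A small PyTorch experiment comparing how batch, instance, and spectral normalization affect per-layer gradient norms. The training script below runs a plain ten-layer CNN and a ResNet-style model on CIFAR-10 for 100 epochs each, recording the L2 norm of every conv layer's gradient once per epoch alongside training loss and validation accuracy. The two architectures are defined in models.py, shown after the training script.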
import torch
import torchvision
from models import TenLayersModel, ResNetLikeModel
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import os
import pickle
def load_cifar():
    # CIFAR-10 with each channel scaled to [-1, 1]
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=trans)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
    testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=trans)
    testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)
    return trainloader, testloader
def calc_gradients(model, input_x, target_y):
    # One forward/backward pass, then the L2 norm of the gradient of
    # every conv kernel (the 4-dim parameters), one value per layer
    model.zero_grad()
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func(model(input_x), target_y)
    loss.backward()
    grads = [p.grad.norm().item() for p in model.parameters() if len(p.size()) == 4]
    return grads
def main(network, normalization):
    if network == "ten":
        model = TenLayersModel(normalization)
    elif network == "resnet":
        model = ResNetLikeModel(normalization)
    model_name = f"{network}_{normalization}"

    output_dir = "snapshot"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    device = "cuda"
    model.to(device)
    model = torch.nn.DataParallel(model)

    trainloader, testloader = load_cifar()
    log_gradients = []
    log_loss = []
    log_val_acc = []

    # weight decay only for the ResNet-like model
    weight_decay = 0.0001 if network == "resnet" else 0
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.1)
    criterion = torch.nn.CrossEntropyLoss()

    max_val_acc = 0.0
    for epoch in tqdm(range(100)):
        # gradient check: per-layer gradient norms measured on the test set
        # (the model stays in train mode throughout, so BatchNorm always
        # uses batch statistics, here and in the validation pass below)
        gradients = []
        for X, y in testloader:
            if len(X) != 128: continue  # skip the last, smaller batch
            X, y = X.to(device), y.to(device)
            gradients.append(calc_gradients(model, X, y))
        model.zero_grad()
        layer_grads = np.mean(np.array(gradients), axis=0)  # batch-wise mean
        log_gradients.append(layer_grads)

        # train
        train_loss = 0.0
        for i, (X, y) in enumerate(trainloader):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred = model(X)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        log_loss.append(train_loss / (i + 1))

        # validation
        with torch.no_grad():
            correct, total = 0, 0
            for X, y in testloader:
                X, y = X.to(device), y.to(device)
                outputs = model(X)
                _, pred = torch.max(outputs, 1)
                total += y.size(0)
                correct += (pred == y).sum().item()
            log_val_acc.append(correct / total)

        # save the best model so far
        if max_val_acc < log_val_acc[-1]:
            torch.save(model.state_dict(), f"{output_dir}/{model_name}.pytorch")
            max_val_acc = log_val_acc[-1]

        scheduler.step()
        print(f"Epoch = {epoch} Loss = {log_loss[-1]} Val_acc = {log_val_acc[-1]} / {model_name}")
        # print(log_gradients[-1])

    # save result
    with open(f"{output_dir}/log_{model_name}.pkl", "wb") as fp:
        result = {"gradient": log_gradients, "loss": log_loss, "val_acc": log_val_acc}
        pickle.dump(result, fp)
if __name__ == "__main__":
    for model in ["ten", "resnet"]:
        for norm in ["batch", "instance", "spectral"]:
            main(model, norm)
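The script pickles its logs to snapshot/log_{network}_{normalization}.pkl. Below is a minimal sketch for inspecting one run afterwards, assuming matplotlib is available; the file name and dictionary keys come from the script above, everything else is illustrative:

import pickle
import numpy as np
import matplotlib.pyplot as plt

# load one run's log; "gradient" holds one array per epoch,
# with one gradient L2 norm per conv layer
with open("snapshot/log_ten_batch.pkl", "rb") as fp:
    log = pickle.load(fp)

grads = np.array(log["gradient"])  # shape: (epochs, conv_layers)
for layer in range(grads.shape[1]):
    plt.plot(grads[:, layer], label=f"conv {layer}")
plt.xlabel("epoch")
plt.ylabel("gradient L2 norm")
plt.yscale("log")  # norms can span orders of magnitude across layers
plt.legend()
plt.show()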
models.py
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class TenLayersModel(nn.Module):
    # Plain CNN: 9 conv layers (3 per resolution) plus 1 linear layer
    def __init__(self, normalization):
        super().__init__()
        self.convs = self.create_model(normalization)
        self.linear = nn.Linear(256, 10)

    def conv_norm_relu(self, in_ch, out_ch, normalization):
        layers = []
        if normalization == "spectral":
            # spectral normalization wraps the conv weight itself
            w = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            layers.append(spectral_norm(w))
        else:
            layers.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
        if normalization == "batch":
            layers.append(nn.BatchNorm2d(out_ch))
        elif normalization == "instance":
            layers.append(nn.InstanceNorm2d(out_ch))
        layers.append(nn.ReLU(inplace=True))
        return layers

    def create_model(self, normalization):
        layers = []
        for in_ch in [3, 64, 64]:
            layers += self.conv_norm_relu(in_ch, 64, normalization)
        layers.append(nn.AvgPool2d(2))  # 32x32 -> 16x16
        for in_ch in [64, 128, 128]:
            layers += self.conv_norm_relu(in_ch, 128, normalization)
        layers.append(nn.AvgPool2d(2))  # 16x16 -> 8x8
        for in_ch in [128, 256, 256]:
            layers += self.conv_norm_relu(in_ch, 256, normalization)
        layers.append(nn.AvgPool2d(8))  # global average pooling
        return nn.Sequential(*layers)

    def forward(self, inputs):
        x = self.convs(inputs).view(inputs.size(0), -1)
        x = self.linear(x)
        return x
class ResNetPreactModule(nn.Module):
    # Pre-activation residual block: norm -> ReLU -> conv, twice
    def __init__(self, in_ch, out_ch, downsampling, normalization):
        assert normalization in ["batch", "spectral", "instance"]
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # 1x1 conv on the shortcut only when the channel count changes
        self.shortcut_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else None
        self.norm1 = self.get_normalization(in_ch, normalization)
        self.norm2 = self.get_normalization(out_ch, normalization)
        if normalization == "spectral":
            self.conv1 = spectral_norm(self.conv1)
            self.conv2 = spectral_norm(self.conv2)
            self.shortcut_conv = spectral_norm(self.shortcut_conv) if self.shortcut_conv is not None else None
        self.downsampling = nn.AvgPool2d(downsampling) if downsampling > 1 else None

    def get_normalization(self, ch, normalization):
        if normalization == "batch":
            return nn.BatchNorm2d(ch)
        elif normalization == "instance":
            return nn.InstanceNorm2d(ch)
        else:
            return None  # for "spectral", normalization lives on the convs

    def forward(self, inputs):
        # main path
        x = self.norm1(inputs) if self.norm1 is not None else inputs
        x = F.relu(x)
        x = self.conv1(x)
        x = self.norm2(x) if self.norm2 is not None else x
        x = F.relu(x)
        x = self.conv2(x)
        # shortcut path
        shortcut = self.shortcut_conv(inputs) if self.shortcut_conv is not None else inputs
        # downsampling applied to both paths
        if self.downsampling is not None:
            x = self.downsampling(x)
            shortcut = self.downsampling(shortcut)
        return x + shortcut
class ResNetLikeModel(nn.Module):
    # ResNet-style model: 3/4/6 pre-activation blocks at 64/128/256 channels
    def __init__(self, normalization):
        super().__init__()
        self.conv = nn.Sequential(
            *self.resnet_block(3, 64, 3, normalization),
            *self.resnet_block(64, 128, 4, normalization),
            *self.resnet_block(128, 256, 6, normalization, enable_downsampling=False),
            nn.AvgPool2d(8)  # global average pooling over the 8x8 map
        )
        # final norm before the classifier (pre-activation style)
        if normalization == "batch":
            self.last_norm = nn.BatchNorm2d(256)
        elif normalization == "instance":
            self.last_norm = nn.InstanceNorm2d(256)
        else:
            self.last_norm = None
        self.linear = nn.Linear(256, 10)

    def resnet_block(self, in_ch, out_ch, reps, normalization, enable_downsampling=True):
        layers = []
        for i in range(reps):
            current_in = in_ch if i == 0 else out_ch
            # downsample in the last block of the stage
            down = 2 if i == reps - 1 and enable_downsampling else 1
            layers.append(ResNetPreactModule(current_in, out_ch, down, normalization))
        return layers

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.last_norm(x) if self.last_norm is not None else x
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
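A quick way to sanity-check both architectures is a dummy forward pass. The sketch below is illustrative, not part of the gist; it assumes a recent PyTorch with torch.linalg, and also checks the property spectral_norm is supposed to enforce (the largest singular value of the flattened conv kernel stays near 1, making "spectral" a Lipschitz constraint on the weights rather than a normalization of activations like the other two):

import torch
from models import TenLayersModel, ResNetLikeModel

x = torch.randn(4, 3, 32, 32)  # dummy CIFAR-10 batch
for norm in ["batch", "instance", "spectral"]:
    for Model in [TenLayersModel, ResNetLikeModel]:
        out = Model(norm)(x)
        assert out.shape == (4, 10)  # 10 CIFAR-10 classes

# spectral_norm reparameterizes the weight as W / sigma(W); after a
# forward pass the stored weight should have top singular value ~1
# (power iteration is approximate, so not exactly 1)
model = TenLayersModel("spectral")
model(x)
conv = next(m for m in model.modules() if isinstance(m, torch.nn.Conv2d))
sigma = torch.linalg.svdvals(conv.weight.detach().flatten(1))[0]
print(sigma)  # ~1.0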