import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import math
import random
from PIL import ImageOps, ImageEnhance, ImageFilter
import numpy as np
# RandAugment is assumed to be the `pytorch-randaugment` package
# (ildoonet/pytorch-randaugment), which exposes `from RandAugment import RandAugment`.
from RandAugment import RandAugment

# Optional TPU backend (disabled)
#import torch_xla
#import torch_xla.core.xla_model as xm
#device = xm.xla_device()

TPU = False    # set True (and uncomment torch_xla above) to step the optimizer via XLA
HALF = False   # set True to run the model and inputs in fp16

print(torch.__version__)

torch.backends.cudnn.benchmark = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#if TPU:
#    device = xm.xla_device()
print(device)
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0),
                                 interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    RandAugment(2, 5),
    #transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
    #transforms.RandomHorizontalFlip(),
    #RandAugment(2, 5),
    #EMB(),
    #transforms.Resize((64,64)),
    #transforms.RandomCrop(32, padding=4),
    #CropUpper(),
    #transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    #transforms.RandomErasing(value="random"),
])
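# RandAugment(2, 5) above: assuming the ildoonet pytorch-randaugment implementation,
# this applies N=2 randomly chosen augmentation ops per image at magnitude M=5.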
transform_test = transforms.Compose([
    transforms.Resize(32, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
class Cutout(object):
    """Randomly zero out a square patch of a (C, H, W) image tensor."""
    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1:y2, x1:x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img

#transform_train.transforms.append(Cutout(8))
#transform_train.transforms.append(Cutout(10))
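# Note: Cutout expects a (C, H, W) tensor (it reads img.size(1)/img.size(2)), so the
# commented-out appends above correctly place it after ToTensor()/Normalize at the end
# of transform_train.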
BATCH_SIZE = 64

train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
class AuxiliaryHeadCIFAR(nn.Module):
    def __init__(self, C, num_classes):
        """assuming input size 8x8"""
        super(AuxiliaryHeadCIFAR, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),  # image size = 2 x 2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x
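# Note: AuxiliaryHeadCIFAR (an auxiliary classifier in the NASNet/DARTS style) is
# defined but never instantiated in this script; only the TF/CT transformer below is
# trained.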
class ImageProcess(nn.Module):
    """Convolutional stem: 3x3 conv + SiLU + strided max-pool, (B, 3, 32, 32) -> (B, ch, 16, 16)."""
    def __init__(self, ch=32):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, ch, 3, 1, 1, bias=True),
            nn.SiLU(),
            #nn.Conv2d(32, ch, 1, bias=True),
            nn.MaxPool2d(3, 2, 1)
        )

    def forward(self, x):
        x = self.conv(x)
        return x
class Attention(nn.Module):
    """Multi-head self-attention over a (B, T, C) token sequence."""
    def __init__(self, emb_dim, n_head):
        super().__init__()
        # key, query, value projections for all heads
        self.key = nn.Linear(emb_dim, emb_dim, bias=False)
        self.query = nn.Linear(emb_dim, emb_dim, bias=False)
        self.value = nn.Linear(emb_dim, emb_dim, bias=False)
        self.attn_drop = nn.Dropout(0.1)
        # output projection
        self.proj = nn.Linear(emb_dim, emb_dim)
        self.n_head = n_head
        self.norm = nn.LayerNorm(emb_dim)  # unused (kept for the commented-out variant below)

    def forward(self, x):
        B, T, C = x.size()
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)    # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        # self-attend (no causal mask): (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        #y = self.norm(y)
        #y = F.gelu(y)
        # output projection
        y = self.proj(y)
        return y
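# Shape sanity check for Attention (illustrative sketch, not part of the original
# script); C must be divisible by n_head:
#   attn = Attention(emb_dim=32, n_head=2)
#   y = attn(torch.randn(4, 16 * 16, 32))   # -> torch.Size([4, 256, 32])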
class SelfAttention(nn.Module):
    """ an unassuming Transformer block """
    def __init__(self, emb_dim, n_head, s):
        super().__init__()
        self.emb_dim = emb_dim
        self.s = s  # side length of the token grid (sequence length = s * s)
        # depthwise 3x3 conv applied on the token grid before attention
        self.c = nn.Conv2d(emb_dim, emb_dim, 3, 1, 1, groups=emb_dim, bias=False)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)
        #self.norm1 = nn.BatchNorm1d(emb_dim)
        #self.norm2 = nn.BatchNorm1d(emb_dim)
        self.norm3 = nn.LayerNorm(emb_dim)  # unused
        self.attn = Attention(emb_dim, n_head)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 2 * emb_dim),
            #nn.LayerNorm(2*emb_dim),
            nn.GELU(),
            #nn.LayerNorm(2*emb_dim),
            nn.Linear(2 * emb_dim, emb_dim),
            #nn.GELU(),
            #nn.LayerNorm(emb_dim)
        )

    def drop_path(self, x, drop_prob):
        # stochastic depth: randomly drop the whole residual branch per sample
        if self.training and drop_prob > 0.:
            keep_prob = 1. - drop_prob
            mask = torch.FloatTensor(x.size(0), 1, 1).bernoulli_(keep_prob).to(device)
            x.div_(keep_prob)
            x.mul_(mask)
        return x

    def forward(self, x):
        B, _, _ = x.size()
        # run the depthwise conv over the (B, C, s, s) view of the tokens
        x1 = x.reshape([B, self.emb_dim, self.s, self.s])
        x1 = self.c(x1)
        x1 = x1.reshape([B, self.s * self.s, self.emb_dim])
        x = x + self.drop_path(self.attn(self.norm1(x1)), 0.0)
        x = x + self.drop_path(self.mlp(self.norm2(x)), 0.0)
        #x = x + self.norm1(self.drop_path(self.attn(x), 0.0))
        #x = x + self.norm2(self.drop_path(self.mlp(x), 0.0))
        return x
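# Note: both residual branches above call drop_path with drop_prob=0.0, so stochastic
# depth is effectively disabled; the hook is kept for experiments (see the
# commented-out `net.drop_path_prob = ...` line in train() below).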
class Encoder(nn.Module):
    """Three-stage token pyramid: 16x16 tokens @ 32 dims -> 8x8 @ 92 -> 4x4 @ 256."""
    def __init__(self, size=16, emb_dim=32, n_head=4, n_layers=1):
        super().__init__()
        sizes = [size, size // 2, size // 4]
        self.b1 = nn.Sequential(*[SelfAttention(32, 2, 16) for _ in range(1)])
        self.l1 = nn.Linear(32, 92)
        self.p1 = nn.MaxPool2d(2, 2)
        self.b2 = nn.Sequential(*[SelfAttention(92, 2, 8) for _ in range(1)])
        self.l2 = nn.Linear(92, 256)
        self.p2 = nn.MaxPool2d(2, 2)
        self.b3 = nn.Sequential(*[SelfAttention(256, 2, 4) for _ in range(2)])

    def forward(self, x):
        B, _, _ = x.size()
        x = self.b1(x)                    # (B, 256, 32)
        x = x.reshape([B, 32, 16, 16])
        x = self.p1(x)                    # 16x16 -> 8x8
        x = x.reshape([B, 8 * 8, 32])
        x = self.l1(x)                    # 32 -> 92 dims
        x = self.b2(x)
        #x1 = x
        x = x.reshape([B, 92, 8, 8])
        x = self.p2(x)                    # 8x8 -> 4x4
        x = x.reshape([B, 4 * 4, 92])
        x = self.l2(x)                    # 92 -> 256 dims
        x = self.b3(x)                    # (B, 16, 256)
        return x
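# Note: Encoder's constructor arguments (size, emb_dim, n_head, n_layers) do not shape
# the stages; the widths (32 -> 92 -> 256), head counts, and depths above are
# hard-coded, and `sizes` is computed but never read.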
class TF(nn.Module):
    def __init__(self, n_class=10, emb_dim=192, n_head=3, n_layers=7):
        super().__init__()
        self.image_process = ImageProcess(32)
        self.flattener = nn.Flatten(2, 3)
        self.enc = Encoder(size=32, emb_dim=emb_dim, n_head=n_head, n_layers=n_layers)
        #self.pe = nn.Parameter(torch.zeros(1, 16*16, 32), requires_grad=True)
        # fixed sinusoidal positional embedding for the 16x16 grid of 32-dim tokens
        self.pe = nn.Parameter(self.emb_sine(16 * 16, 32), requires_grad=False)
        self.seq_pool = nn.Linear(emb_dim, 1)  # only used by the commented-out sequence pooling below
        self.last_layer = nn.Linear(emb_dim, n_class)

    def emb_sine(self, n_channels, dim):
        # standard sinusoidal positional encoding, returned with shape (1, n_channels, dim)
        pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
                                for p in range(n_channels)])
        pe[:, 0::2] = torch.sin(pe[:, 0::2])
        pe[:, 1::2] = torch.cos(pe[:, 1::2])
        return pe.unsqueeze(0)

    def forward(self, x):
        x = self.image_process(x)                  # (B, 3, 32, 32) -> (B, 32, 16, 16)
        """
        B, C, H, W = x.size()
        axis_pos = [torch.linspace(-1.0, 1.0, size) for size in (H, W)]
        p = torch.stack(torch.meshgrid(*axis_pos), dim=-1).to(device)
        p = p.permute(2, 0, 1)
        p = p.repeat(B, 1, 1, 1)
        x = torch.cat([x, p], dim=1)
        """
        x = self.flattener(x).transpose(-2, -1)    # (B, 256, 32) token sequence
        x = x + self.pe
        x = self.enc(x)
        #x1 = x1.mean(dim=1)
        x = x.mean(dim=1)                          # mean-pool over tokens
        #print(x.shape)
        #print(x1.shape)
        #x = torch.cat([x1, x], dim=1)
        #print(x.shape)
        #x = torch.matmul(F.softmax(self.seq_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
        x = self.last_layer(x)
        return x


def CT():
    model = TF(n_class=10, emb_dim=256, n_head=1, n_layers=6)
    return model
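# Quick shape trace for CT() (illustrative sketch, not part of the original script):
#   m = CT()
#   logits = m(torch.randn(2, 3, 32, 32))   # -> torch.Size([2, 10])
# emb_dim=256 must match the Encoder's final stage width; n_head and n_layers are
# passed through but not used by Encoder.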
def resnet18(pretrained=False, **kwargs):
    # unused here (see `net = CT()` below); BNet is not defined in this gist
    model = BNet()
    return model
class LabelSmoothingCrossEntropy(nn.Module):
    """
    NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def _compute_losses(self, x, target):
        log_prob = F.log_softmax(x, dim=-1)
        nll_loss = -log_prob.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -log_prob.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss

    def forward(self, x, target):
        return self._compute_losses(x, target).mean()
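# Note: with smoothing=0.1 this should be roughly equivalent to
# nn.CrossEntropyLoss(label_smoothing=0.1) (available in PyTorch 1.10+); the explicit
# class is kept as in the original gist.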
#net = resnet18()
net = CT()
net = net.to(device)
if HALF:
    net = net.half()

train_criterion = LabelSmoothingCrossEntropy().to(device)
criterion = nn.CrossEntropyLoss().to(device)  # unused: the test loop only tracks accuracy

# parameter groups for the commented-out SGD configuration below
biases = []
weights = []
"""
optimizer = optim.SGD([
    {"params": weights},
    #{"params": biases}
    {"params": biases, "weight_decay": 0.0}
], lr=1e-1, momentum=0.9, weight_decay=1e-4, nesterov=True)
"""
"""
optimizer = optim.SGD(
    net.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-6)
"""
optimizer = optim.AdamW(
    net.parameters(),
    lr=0.001,
    weight_decay=1e-4
)

#scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,60], gamma=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)

EPOCH = 200
def train(epoch, optimizer, scheduler):
    print('\nEpoch: %d' % epoch)
    net.to(device)
    net.train()
    print(scheduler.get_last_lr())
    train_loss = 0
    correct = 0
    total = 0
    cos_sim = nn.CosineSimilarity(dim=1).to(device)  # only used by the commented-out consistency loss below
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        #net.drop_path_prob = 0.1 * epoch / EPOCH
        inputs, targets = inputs.to(device), targets.to(device)
        if HALF:
            # cast only the inputs; targets must stay int64 class indices for the loss
            inputs = inputs.half()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = train_criterion(outputs, targets)
        #with torch.no_grad():
        #    out2 = net(inputs)
        #out2 = out2.detach()
        #loss = loss*0.1 + (-0.9 * cos_sim(outputs, out2).mean())
        loss.backward()
        if TPU:
            xm.optimizer_step(optimizer, barrier=True)
        else:
            optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 200 == 0:
            print('%.3f | %.3f%% (%d/%d)' % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    print('Acc: %.3f%% (%d/%d)' % (100. * correct / total, correct, total))
    scheduler.step()
def test(epoch, optimizer):
    net.eval()
    test_loss = 0
    correct = 0
    correct_aux = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            #optimizer.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)
            if HALF:
                # as in train(): cast only the inputs
                inputs = inputs.half()
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            #_, predicted = aux.max(1)
            #correct_aux += predicted.eq(targets).sum().item()

    print('Eval %.3f%% (%d/%d)' % (100. * correct / total, correct, total))
    #print('Eval %.3f%% (%d/%d)' % (100.*correct_aux/total, correct_aux, total))
for epoch in range(0, EPOCH):
    train(epoch, optimizer, scheduler)
    test(epoch, optimizer)
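# Optional (not part of the original gist): persist the trained weights after the loop.
# The filename is arbitrary.
#torch.save(net.state_dict(), "ct_cifar10.pt")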