Lightning and 16-bit precision, range test
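The description refers to training with 16-bit precision and a learning-rate range test. A minimal sketch of how both would plug into the Trainer used in main() below, assuming the 0.x-era pytorch-lightning API this gist is written against (on releases of that vintage, precision=16 may additionally require NVIDIA apex to be installed); `model`, `train_loader`, `val_loader`, and `args` are assumed to be built exactly as in main():

from pytorch_lightning import Trainer

# Sketch only, not the gist author's exact configuration.
trainer = Trainer(gpus=1, max_epochs=args.max_epoch, default_root_dir=args.checkpoint_dir,
                  precision=16)  # 16-bit training; older releases may also need apex

# LR range test (same call as the commented-out lines in main()):
lr_finder = trainer.lr_find(model, train_dataloader=train_loader, val_dataloaders=val_loader)
args.lr = lr_finder.suggestion()

trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)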
Check out my git repo styleMix.
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from collections import OrderedDict
from time import time
from pytorch_lightning.core.lightning import LightningModule
from utils import Model_type, euclidean_dist, count_acc
# --- conventional supervised training ---
class BaselineTrain(LightningModule):
    def __init__(self, encoder, args, loss_type='softmax'):
        super().__init__()
        self.encoder = encoder
        self.args = args
        self.epoch = 0
        self.old_time = time()

        # Only ResNet12 / MiniImageNet are handled here; fail loudly otherwise
        # instead of hitting a NameError further down.
        if args.model_type is Model_type.ResNet12:
            final_feat_dim = 640
        else:
            raise ValueError('Unsupported model_type: {}'.format(args.model_type))
        if args.dataset == 'MiniImageNet':
            n_class = 64
        else:
            raise ValueError('Unsupported dataset: {}'.format(args.dataset))

        self.classifier = nn.Linear(final_feat_dim, n_class)
        self.classifier.bias.data.fill_(0)
        self.loss_fn = nn.CrossEntropyLoss()
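
    # forward() has two modes:
    #   'train'        -> features go through the linear classifier over all base classes
    #   'val' / 'test' -> episodic few-shot scoring: class prototypes are the mean of the
    #                     support embeddings, and queries are scored by negative Euclidean
    #                     distance scaled by args.temperature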
    def forward(self, data, mode='train'):
        args = self.args
        if mode not in ['train']:
            data = data.view(args.n_way * (args.n_shot + args.n_query), *data.size()[2:])
            feature = self.encoder(data)
            feature = feature.view(args.n_way, args.n_shot + args.n_query, -1)
            z_support = feature[:, :args.n_shot]
            z_query = feature[:, args.n_shot:]
            proto = z_support.view(args.n_way, args.n_shot, -1).mean(1)
            z_query = z_query.contiguous().view(args.n_way * args.n_query, -1)
            scores = -euclidean_dist(z_query, proto) / self.args.temperature
        else:
            feature = self.encoder(data)
            scores = self.classifier(feature)
        return scores
    def training_step(self, batch, batch_idx):
        args = self.args
        data, index_label = batch[0].cuda(), batch[1].cuda()
        logits = self(data, 'train')
        label = index_label
        loss = F.cross_entropy(logits, label)

        print_freq = 20
        if (batch_idx + 1) % print_freq == 0:
            # 300 is a hard-coded estimate of the number of batches per epoch.
            print('Epoch {}, {}/{}, loss={:.4f}'.format(self.epoch, batch_idx, 300, loss.item()))
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        avg_loss = np.mean([x['loss'].item() for x in outputs])
        # print('Train loss={:.4f}'.format(avg_loss))
        return {'avg_loss': avg_loss}
    def validation_step(self, batch, batch_idx):
        args = self.args
        data, index_label = batch[0].cuda(), batch[1].cuda()
        label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
        label = label.cuda()
        logits = self(data, mode='val')
        loss = F.cross_entropy(logits, label)
        acc = count_acc(logits, label)
        return {'val_acc': acc, 'val_loss': loss}

    def validation_epoch_end(self, outputs):
        self.epoch = self.epoch + 1
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = np.mean([x['val_acc'] for x in outputs])
        print('Validation loss={:.4f} acc={:.4f}, time={:.3f}'.format(avg_loss.item(), avg_acc, time() - self.old_time))
        self.old_time = time()
        return {'val_acc': avg_acc, 'val_loss': avg_loss}
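
    # Note: configure_optimizers returns ([optimizer], [scheduler]); Lightning steps a
    # scheduler returned this way once per epoch by default, which matches
    # CosineAnnealingLR with T_max=args.max_epoch.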
    def configure_optimizers(self):
        args = self.args
        from torch.optim import SGD, lr_scheduler
        optimizer = SGD(self.parameters(), lr=self.args.lr, momentum=0.9, weight_decay=5e-4)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
        return [optimizer], [scheduler]
import os
import os.path as osp

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from data.datamgr import SimpleDataManager, SetDataManager
from methods.StyleMix import StyleMix
from utils import Averager, Timer, count_acc, compute_confidence_interval, Model_type, Method_type, \
    save_model, load_pretrained_weights, init, resume_model
def train_one_epoch(model, scheduler, optimizer, args, train_loader, label, writer, epoch):
    # for i in range(len(train_loader)):
    #     scheduler.step()
    # return
    model.train()
    print_freq = 10
    for i, batch in enumerate(train_loader):
        data, index_label = batch[0].cuda(), batch[1].cuda()
        if args.method_type is Method_type.style:
            logits, logits1 = model(data, 'train')
            loss = F.cross_entropy(logits, label) + F.cross_entropy(logits1, label)
            if args.exp_tag in ['same_labels']:
                p, q = F.softmax(logits, dim=1), F.softmax(logits1, dim=1)
                loss1 = F.kl_div(p.log(), q, reduction='batchmean') \
                        + F.kl_div(q.log(), p, reduction='batchmean')
                loss = loss + loss1
        else:
            logits = model(data, 'train')
            if args.method_type is Method_type.baseline:
                label = index_label
            loss = F.cross_entropy(logits, label)

        acc = count_acc(logits, label)
        if i % print_freq == print_freq - 1:
            if args.exp_tag in ['same_labels']:
                print('epoch {}, train {}/{}, loss={:.4f}, KL_loss={:.4f}, acc={:.4f}'.format(
                    epoch, i, len(train_loader), loss.item(), loss1.item(), acc))
            else:
                print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(
                    epoch, i, len(train_loader), loss.item(), acc))
        if writer is not None:
            writer.add_scalar('loss', loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
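
# Note: scheduler.step() above runs once per batch; main() builds OneCycleLR with
# steps_per_epoch=len(train_loader), so the schedule is meant to advance every iteration.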
def val(model, args, val_loader, label):
    model.eval()
    vl = Averager()
    va = Averager()

    with torch.no_grad():
        for i, batch in tqdm(enumerate(val_loader, 1), total=len(val_loader)):
            data, index_label = batch[0].cuda(), batch[1].cuda()
            logits = model(data, mode='val')
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            vl.add(loss.item())
            va.add(acc)

    vl = vl.item()
    va = va.item()
    return vl, va
def test(model, label, args, few_shot_params):
    if args.debug:
        n_test = 10
        print_freq = 2
    else:
        n_test = 1000
        print_freq = 100
    test_file = args.dataset_dir + 'test.json'
    test_datamgr = SetDataManager(args.exp_tag, test_file, args.dataset_dir, args.image_size,
                                  mode='val', n_episode=n_test, **few_shot_params)
    loader = test_datamgr.get_data_loader(aug=False)
    test_acc_record = np.zeros((n_test,))

    warmup_state = torch.load(osp.join(args.checkpoint_dir, 'max_acc' + '.pth'))['params']
    model.load_state_dict(warmup_state, strict=False)
    model.eval()

    ave_acc = Averager()
    with torch.no_grad():
        for i, batch in enumerate(loader, 1):
            data, index_label = batch[0].cuda(), batch[1].cuda()
            logits = model(data, 'test')
            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc
            if i % print_freq == 0:
                print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

    m, pm = compute_confidence_interval(test_acc_record)
    # print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc'],
    #                                                               ave_acc.item()))
    print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
    acc_str = '%4.2f' % (m * 100)
    with open(args.save_dir + '/result.txt', 'a') as f:
        f.write('%s %s\n' % (acc_str, args.name))
def main():
    timer = Timer()
    args, writer = init()

    if args.exp_tag in ['sen']:
        if args.test:
            from sen.main_base import base_test
            return base_test(args)
        else:
            from sen.main_base import base_train
            return base_train(args)

    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'

    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    n_episode = 10 if args.debug else 100
    if args.method_type is Method_type.baseline:
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=128)
        train_loader = train_datamgr.get_data_loader(aug=True)
    else:
        train_datamgr = SetDataManager(args.exp_tag, train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
        train_loader = train_datamgr.get_data_loader(aug=True)

    val_datamgr = SetDataManager(args.exp_tag, val_file, args.dataset_dir, args.image_size,
                                 n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)

    if args.model_type is Model_type.ConvNet:
        pass  # ConvNet backbone not wired up in this snippet
    elif args.model_type is Model_type.ResNet12:
        # from networks.resnet import resnet12
        # encoder = resnet12(exp_tag=args.exp_tag)
        from networks.sen_backbone import ResNet12
        encoder = ResNet12(args.exp_tag)
    else:
        raise ValueError('Unknown model_type: {}'.format(args.model_type))
    if args.method_type is Method_type.baseline:
        # from methods.baselinetrain import BaselineTrain
        # model = BaselineTrain(encoder, args)
        from lightning.baseline import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    elif args.method_type is Method_type.style:
        model = StyleMix(encoder, args, dropout=0.5)
    else:
        raise ValueError('Unknown method_type: {}'.format(args.method_type))
    model = model.cuda()
    os.environ["KMP_WARNINGS"] = "FALSE"
    import warnings
    warnings.filterwarnings("ignore")

    from pytorch_lightning import Trainer
    trainer = Trainer(gpus=1, default_root_dir=args.checkpoint_dir, max_epochs=args.max_epoch,
                      val_percent_check=1.0, fast_dev_run=False, profiler=False, progress_bar_refresh_rate=0)
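    # These Trainer arguments are from the 0.x-era pytorch-lightning API used here;
    # later releases renamed or removed some of them (e.g. val_percent_check became
    # limit_val_batches, and progress_bar_refresh_rate was eventually dropped).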
    print('len of dataloader: ' + str(len(train_loader)))
    # lr_finder = trainer.lr_find(model, train_dataloader=train_loader, val_dataloaders=val_loader)
    # args.lr = lr_finder.suggestion()
    # print('learning rate suggested by lightning: ' + str(args.lr))
    # model.hparams.lr = args.lr
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
    return
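
    # Everything below is the original plain-PyTorch training/testing path; it is
    # unreachable while the `return` above is in place and is kept for reference.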
    from torch.optim import SGD, lr_scheduler
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=len(train_loader), epochs=args.max_epoch)
    # scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()

    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()
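    # `label` is the fixed episodic query label: each of the n_way classes repeated
    # n_query times, matching the order in which the episodic loaders are assumed to
    # arrange the query samples of an episode.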
    if args.test:
        test(model, label, args, few_shot_params)
        return

    if args.resume:
        resume_OK = resume_model(model, optimizer, args)
    else:
        resume_OK = False
    if (not resume_OK) and (args.warmup is not None):
        load_pretrained_weights(model, args)

    max_acc = 0.0
    if args.debug:
        args.max_epoch = args.start_epoch + 1

    for epoch in range(args.start_epoch, args.max_epoch):
        print('learning rate: ' + str(optimizer.param_groups[0]['lr']))
        train_one_epoch(model, scheduler, optimizer, args, train_loader, label, writer, epoch)
        # continue
        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        if va >= max_acc:
            max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, 'max_acc')
        save_model(model, optimizer, args, epoch, 'epoch-last')

        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
        # return

    if writer is not None:
        writer.close()

    # Test Phase
    test(model, label, args, few_shot_params)
if __name__ == '__main__':
    main()