@archydeberker
Created April 14, 2019 21:33
```python
import torch

# Assumes Net, train_model, test_model, trainset_loaders, valloader,
# testloader, dataset_size and total_train are defined earlier in the notebook.
test_acc = {}
val_acc = {}
train_acc = {}
test_loss = {}
for train_size in dataset_size:
    print('Training with subset %1.4f, which is %d images' % (train_size, int(train_size * total_train)))
    net = Net()
    # Train model with an early stopping criterion - terminates after 4 epochs of non-improving val loss
    net, loss_list, val_list = train_model(net, trainset_loaders[train_size], valloader, 1000, n_epochs=10)
    # accuracy is a fraction in [0, 1]; loss is the test-set loss
    accuracy, loss = test_model(net, testloader)
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * accuracy))
    test_acc[train_size] = accuracy
    test_loss[train_size] = loss
    # Per-epoch histories returned by train_model (loss_list, val_list)
    val_acc[train_size] = val_list
    train_acc[train_size] = loss_list
    torch.save(net, 'trainset_%1.2f_%d_images.model' % (train_size, int(train_size * total_train)))
```
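The loop indexes `trainset_loaders` by fraction, but the gist does not show how those loaders were built. A minimal sketch, assuming each subset is a nested random sample of one training set (so every smaller subset is contained in every larger one, which keeps the comparison across sizes cleaner). The helper `nested_subset_indices`, the fractions, and the 50000-image total are hypothetical, not taken from the gist.

```python
import random


def nested_subset_indices(n_total, fractions, seed=0):
    """Return {fraction: list of indices}. All subsets are prefixes of one
    shared random permutation, so each subset is contained in every larger one."""
    rng = random.Random(seed)
    order = list(range(n_total))
    rng.shuffle(order)
    # int() truncates, matching the %d image count printed in the loop
    return {f: order[:int(f * n_total)] for f in fractions}


# Hypothetical fractions; the gist's actual dataset_size values are not shown.
dataset_size = [0.01, 0.1, 0.5, 1.0]
subsets = nested_subset_indices(50000, dataset_size)
for f in dataset_size:
    print('fraction %1.4f -> %d images' % (f, len(subsets[f])))
```

With PyTorch, each index list could then be wrapped as `torch.utils.data.Subset(trainset, indices)` and passed to a `torch.utils.data.DataLoader` to produce one entry of `trainset_loaders`.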
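The comment in the loop says training stops after 4 epochs of non-improving validation loss. The body of `train_model` is not shown, so the following is only a minimal sketch of that patience rule; the `EarlyStopping` class, its `min_delta` parameter, and the loss values are hypothetical.

```python
class EarlyStopping:
    """Signal a stop once validation loss has failed to improve
    for `patience` consecutive epochs."""

    def __init__(self, patience=4, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # New best: reset the counter of non-improving epochs
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses, for illustration only.
val_losses = [1.20, 0.95, 0.90, 0.92, 0.91, 0.93, 0.94, 0.89]
stopper = EarlyStopping(patience=4)
for epoch, val_loss in enumerate(val_losses, start=1):
    if stopper.step(val_loss):
        print(f"stopping after epoch {epoch}; best val loss {stopper.best_loss:.2f}")
        break
```

Here the best loss (0.90) occurs at epoch 3, and epochs 4 through 7 fail to beat it, so training stops at epoch 7 without seeing the later 0.89. In the gist's call, `n_epochs=10` would additionally cap the run at 10 epochs regardless of patience.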