Created July 3, 2019 21:17
index ede2865..de5eb9f 100755
--- a/examples/maml-omniglot.py
+++ b/examples/maml-omniglot.py
@@ -30,6 +30,7 @@ import higher
 
 from omniglot_loaders import OmniglotNShot
 
+
 def main():
     argparser = argparse.ArgumentParser()
     argparser.add_argument('--n_way', type=int, help='n way', default=5)
@@ -58,23 +59,22 @@ def main():
     # Before higher, models could *not* be created like this
     # and the parameters needed to be manually updated and copied
     # for the updates.
+
     net = nn.Sequential(
-        nn.Conv2d(1, 64, 3, 2),
-        nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 3, 2),
+        nn.Conv2d(1, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 3, 2),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 2, 1),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
+        nn.MaxPool2d(2,2),
         Flatten(),
-        nn.Linear(64, args.n_way)
-    ).to(device)
-
+        nn.Linear(64, args.n_way)).to(device)
     # We will use Adam to (meta-)optimize the initial parameters
     # to be adapted.
     meta_opt = optim.Adam(net.parameters(), lr=1e-3)
@@ -90,6 +90,7 @@ def train(db, net, device, meta_opt, epoch, log):
 
     net.train()
     n_train_iter = db.x_train.shape[0] // db.batchsz
+
     for batch_idx in range(n_train_iter):
         # Sample a batch of support and query images and labels.
         x_spt, y_spt, x_qry, y_qry = db.next()
@@ -107,7 +108,7 @@ def train(db, net, device, meta_opt, epoch, log):
         # Initialize the inner optimizer to adapt the parameters to
         # the support set.
         n_inner_iter = 1
-        inner_opt = torch.optim.SGD(net.parameters(), lr=4e-1)
+        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
 
         qry_losses = []
         qry_accs = []
@@ -167,6 +168,7 @@ def test(db, net, device, epoch, log):
 
     qry_losses = []
     qry_accs = []
+
     for batch_idx in range(n_test_iter):
         x_spt, y_spt, x_qry, y_qry = db.next('test')
 
@@ -180,8 +182,8 @@ def test(db, net, device, epoch, log):
 
         # TODO: Maybe pull this out into a separate module so it
         # doesn't have to be duplicated between `train` and `test`?
-        n_inner_iter = 3
-        inner_opt = torch.optim.SGD(net.parameters(), lr=4e-1)
+        n_inner_iter = 5
+        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
 
         for i in range(task_num):
             with higher.innerloop_ctx(net, inner_opt) as (fnet, diffopt):
@@ -195,7 +197,7 @@ def test(db, net, device, epoch, log):
 
                 # The query loss and acc induced by these parameters.
                 qry_logits = fnet(x_qry[i]).detach()
-                qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
+                qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none').detach()
                 qry_losses.append(qry_loss)
                 qry_accs.append(
                     qry_logits.argmax(dim=1) == y_qry[i]
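
A note on the architecture hunk: the diff replaces the strided convolutions with conv / batch-norm / ReLU / 2x2 max-pool blocks, and the classifier head still works because the spatial dimensions collapse to 1x1. Below is a self-contained shape check, not part of the gist: nn.Flatten stands in for the example's custom Flatten module, and n_way=5 matches the script's default. With the 28x28 Omniglot inputs the example uses, each block maps the spatial size 28 -> 13 -> 5 -> 1, so Flatten yields exactly the 64 features nn.Linear(64, n_way) expects.

# Shape-check sketch (not from the gist); nn.Flatten() stands in for
# the example's custom Flatten, n_way=5 is the script's default.
import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(1, 64, 3),
    nn.BatchNorm2d(64, momentum=1, affine=True),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(64, 64, 3),
    nn.BatchNorm2d(64, momentum=1, affine=True),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(64, 64, 3),
    nn.BatchNorm2d(64, momentum=1, affine=True),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64, 5),
)

# Spatially: 28 -> 26 -> 13 -> 11 -> 5 -> 3 -> 1, leaving 64 features.
print(net(torch.randn(2, 1, 28, 28)).shape)  # torch.Size([2, 5])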
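The remaining hunks retune the inner loop (SGD learning rate 4e-1 -> 1e-1, 5 test-time adaptation steps instead of 3) and detach the query loss at test time so its graph is freed. For readers new to higher, here is a minimal sketch of the inner-loop pattern the example is built on; it is not part of the gist, and the toy linear model and random tensors are stand-ins for the conv net and the Omniglot support/query batches, while lr=1e-1 and the 5 inner steps mirror the diff's values.

# Toy sketch (not from the gist) of higher's innerloop_ctx pattern:
# adapt a differentiable copy of `net` on support data, then backprop
# the query loss through the adaptation to the meta-parameters.
import torch
import torch.nn as nn
import torch.nn.functional as F
import higher

net = nn.Linear(8, 5)  # stand-in for the conv net
x_spt, y_spt = torch.randn(25, 8), torch.randint(0, 5, (25,))
x_qry, y_qry = torch.randn(75, 8), torch.randint(0, 5, (75,))

meta_opt = torch.optim.Adam(net.parameters(), lr=1e-3)
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

meta_opt.zero_grad()
with higher.innerloop_ctx(net, inner_opt) as (fnet, diffopt):
    for _ in range(5):  # n_inner_iter, as in the diff's test()
        # Differentiable update of the functional model's parameters.
        diffopt.step(F.cross_entropy(fnet(x_spt), y_spt))
    qry_loss = F.cross_entropy(fnet(x_qry), y_qry)
    qry_loss.backward()  # gradients flow back to net.parameters()
meta_opt.step()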