diff --git a/examples/maml-omniglot.py b/examples/maml-omniglot.py
index ede2865..de5eb9f 100755
--- a/examples/maml-omniglot.py
+++ b/examples/maml-omniglot.py
@@ -30,6 +30,7 @@ import higher
 
 from omniglot_loaders import OmniglotNShot
 
+
 def main():
     argparser = argparse.ArgumentParser()
     argparser.add_argument('--n_way', type=int, help='n way', default=5)
@@ -58,23 +59,22 @@ def main():
     # Before higher, models could *not* be created like this
     # and the parameters needed to be manually updated and copied
     # for the updates.
+
     net = nn.Sequential(
-        nn.Conv2d(1, 64, 3, 2),
-        nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 3, 2),
+        nn.Conv2d(1, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 3, 2),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 2, 1),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
+        nn.MaxPool2d(2, 2),
         Flatten(),
-        nn.Linear(64, args.n_way)
-    ).to(device)
-
+        nn.Linear(64, args.n_way)).to(device)
     # We will use Adam to (meta-)optimize the initial parameters
     # to be adapted.
     meta_opt = optim.Adam(net.parameters(), lr=1e-3)
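
For reference, here is the post-change network assembled in one place. This is a sketch rather than the gist's exact code: nn.Flatten() stands in for the example's own Flatten helper, and n_way/device are hardcoded to the script's defaults. With 28x28 Omniglot inputs, each 3x3 conv (no padding) plus 2x2 max-pool stage shrinks the feature map 28 -> 13 -> 5 -> 1, so the final linear layer sees exactly 64 features. Note that momentum=1 makes each BatchNorm's running statistics track only the most recent batch, which suits the per-task adaptation setting.

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_way = 5  # the script's --n_way default

# Post-change architecture; nn.Flatten() replaces the example's Flatten helper.
net = nn.Sequential(
    nn.Conv2d(1, 64, 3),
    nn.BatchNorm2d(64, momentum=1, affine=True),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(64, 64, 3),
    nn.BatchNorm2d(64, momentum=1, affine=True),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(64, 64, 3),
    nn.BatchNorm2d(64, momentum=1, affine=True),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),  # 64 x 1 x 1 -> 64 features for 28x28 inputs
    nn.Linear(64, n_way),
).to(device)
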
@@ -90,6 +90,7 @@ def train(db, net, device, meta_opt, epoch, log):
     net.train()
     n_train_iter = db.x_train.shape[0] // db.batchsz
+
     for batch_idx in range(n_train_iter):
         # Sample a batch of support and query images and labels.
         x_spt, y_spt, x_qry, y_qry = db.next()
 
@@ -107,7 +108,7 @@ def train(db, net, device, meta_opt, epoch, log):
         # Initialize the inner optimizer to adapt the parameters to
         # the support set.
         n_inner_iter = 1
-        inner_opt = torch.optim.SGD(net.parameters(), lr=4e-1)
+        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
 
         qry_losses = []
         qry_accs = []
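
The lower inner-loop learning rate matters because this SGD instance drives the differentiable adaptation step. A minimal sketch of how higher consumes it, following the innerloop_ctx pattern visible further down this diff; the toy model and random support/query tensors are stand-ins, not the example's data:

import higher
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the model and one task's support/query batch (assumed
# shapes; the real example uses the conv net and Omniglot loader above).
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 5))
x_spt, y_spt = torch.randn(25, 1, 28, 28), torch.randint(0, 5, (25,))
x_qry, y_qry = torch.randn(75, 1, 28, 28), torch.randint(0, 5, (75,))

inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
n_inner_iter = 1

# fnet is a differentiable (stateless) copy of net; diffopt.step() builds new
# parameters while keeping the graph, so the query loss can backprop through
# the adaptation step to net's initial weights.
with higher.innerloop_ctx(net, inner_opt, copy_initial_weights=False) as (fnet, diffopt):
    for _ in range(n_inner_iter):
        spt_loss = F.cross_entropy(fnet(x_spt), y_spt)
        diffopt.step(spt_loss)
    qry_loss = F.cross_entropy(fnet(x_qry), y_qry)
    qry_loss.backward()  # meta-gradient accumulates in net.parameters()'s .grad

Here copy_initial_weights=False is assumed; it is what lets qry_loss.backward() reach the original parameters so the outer Adam step can update them.
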
@@ -167,6 +168,7 @@ def test(db, net, device, epoch, log):
 
     qry_losses = []
     qry_accs = []
+
     for batch_idx in range(n_test_iter):
         x_spt, y_spt, x_qry, y_qry = db.next('test')
 
@@ -180,8 +182,8 @@ def test(db, net, device, epoch, log):
# TODO: Maybe pull this out into a separate module so it
# doesn't have to be duplicated between `train` and `test`?
- n_inner_iter = 3
- inner_opt = torch.optim.SGD(net.parameters(), lr=4e-1)
+ n_inner_iter = 5
+ inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
for i in range(task_num):
with higher.innerloop_ctx(net, inner_opt) as (fnet, diffopt):
@@ -195,7 +197,7 @@ def test(db, net, device, epoch, log):
                 # The query loss and acc induced by these parameters.
                 qry_logits = fnet(x_qry[i]).detach()
-                qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
+                qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none').detach()
                 qry_losses.append(qry_loss)
                 qry_accs.append(
                     qry_logits.argmax(dim=1) == y_qry[i]
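
The extra .detach() is a safety measure at evaluation time: the logits are already detached, so the loss carries no graph either way, but detaching the stored tensor makes that explicit and guards against accidentally retaining autograd state across tasks. A sketch of how the accumulated lists might then be reduced for logging; the toy tensors and print format are assumptions, not the gist's code:

import torch

# Toy stand-ins for the lists built in the loop above: detached per-example
# loss vectors and boolean correctness masks, one tensor per task.
qry_losses = [torch.rand(75), torch.rand(75)]
qry_accs = [torch.rand(75) > 0.5, torch.rand(75) > 0.5]

# Concatenate across tasks and reduce to scalars.
mean_qry_loss = torch.cat(qry_losses).mean().item()
mean_qry_acc = 100. * torch.cat(qry_accs).float().mean().item()
print(f'test loss: {mean_qry_loss:.3f} | test acc: {mean_qry_acc:.2f}%')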