Created
October 6, 2017 02:22
-
-
Save ruotianluo/043fae22e9f8fd1b36b82189f2356937 to your computer and use it in GitHub Desktop.
A snippet showing how RoIAlign works
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
import numpy as np | |
def pyth_crop_pool_layer(bottom, rois, pool_size, feat_stride=16.0):
    """Crop each RoI from a feature map and bilinearly resample it.

    Every box is mapped from image coordinates to feature-map coordinates by
    dividing by ``feat_stride``. A ``pool_size`` x ``pool_size`` grid of points
    is then sampled from it, evenly spaced from (x1, y1) to (x2, y2) inclusive,
    using ``F.affine_grid`` + ``F.grid_sample`` (bilinear interpolation).

    Args:
        bottom: 4-D feature map (B, C, H, W); must be floating point, and H and
            W must both be > 1 (the normalization divides by H - 1 and W - 1).
        rois: (N, 5) tensor of ``[idx, x1, y1, x2, y2]`` in image coordinates.
            When ``bottom`` holds a single image, column 0 is ignored and every
            RoI reads from that image. When it holds several, column 0 is used
            as the index of the image each RoI is cropped from.
        pool_size: side length of the square output crop.
        feat_stride: image-to-feature-map scale factor (default 16.0, the
            value that used to be hard-coded).

    Returns:
        Tensor of shape (N, C, pool_size, pool_size).
    """
    # Box corners in feature-map coordinates, each of shape (N, 1).
    x1 = rois[:, 1:2] / feat_stride
    y1 = rois[:, 2:3] / feat_stride
    x2 = rois[:, 3:4] / feat_stride
    y2 = rois[:, 4:5] / feat_stride

    height = bottom.size(2)
    width = bottom.size(3)
    num_rois = rois.size(0)

    # Affine map from the output grid ([-1, 1]^2) to normalized input
    # coordinates, under the align_corners=True convention: -1 / +1 are the
    # centres of the first / last input pixel. Output -1 must land on x1 and
    # +1 on x2, which gives scale (x2 - x1)/(W - 1) and offset
    # (x1 + x2 - W + 1)/(W - 1); the same holds for y.
    # zeros_like keeps the dtype/device of the coordinates. It also keeps the
    # tensor kind, so torch.cat never receives a mix of tensors and Variables.
    zero = torch.zeros_like(x1)
    theta = torch.cat([
        (x2 - x1) / (width - 1),
        zero,
        (x1 + x2 - width + 1) / (width - 1),
        zero,
        (y2 - y1) / (height - 1),
        (y1 + y2 - height + 1) / (height - 1),
    ], 1).view(-1, 2, 3)

    # align_corners=True is required for the theta above. PyTorch >= 1.3
    # defaults to False, which would shift and shrink every crop.
    grid = F.affine_grid(
        theta, torch.Size((num_rois, 1, pool_size, pool_size)),
        align_corners=True)

    if bottom.size(0) == 1:
        # Single image: every RoI samples from it (original behaviour).
        images = bottom.expand(num_rois, bottom.size(1), height, width)
    else:
        # Several images: pick each RoI's source image by rois[:, 0]. Plain
        # expand would fail, or pair RoI i with image i.
        images = bottom[rois[:, 0].long()]

    crops = F.grid_sample(images, grid, align_corners=True)
    return crops
# Demo: crop two RoIs from a 7x7 feature map holding the values 0..48.
# The map must be floating point: F.grid_sample rejects integer input, and
# torch.arange with integer arguments returns an int64 tensor in current PyTorch.
bottom = Variable(torch.arange(0, 49).float().view(1, 1, 7, 7))
# Boxes are [idx, x1, y1, x2, y2]. Multiplying by 16 undoes the stride-16
# scaling inside pyth_crop_pool_layer. Column 0 is ignored here because bottom
# holds a single image.
rois = Variable(torch.Tensor([[0, 0, 0, 6, 6], [0, 0, 1, 3, 6]])) * 16
pool_size = 4
print(pyth_crop_pool_layer(bottom, rois, pool_size))
# Print the 7x7 map itself. bottom[0:7, 0:7] only sliced the batch and
# channel dimensions, so it printed the whole 4-D tensor.
print(bottom[0, 0])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for the gist!
It may well be an issue with my configuration, but to make this work I had to change line 23 to: [the replacement code was not preserved in this copy of the page]
torch.cat was complaining about receiving a list as its 0th argument.