Last active
April 2, 2019 08:58
-
-
Save lartpang/77cf8495374dca9283c62f2f4feffd99 to your computer and use it in GitHub Desktop.
对3DGNN中使用的GNN的一些改动
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn.modules.utils import _pair, _quadruple | |
class MedianPool2d(nn.Module):
    """Median pooling over 2-D windows (a median filter when ``stride`` is 1).

    Args:
        kernel_size: window size, an int or an ``(h, w)`` pair.
        stride: window step, an int or an ``(h, w)`` pair.
        padding: explicit padding, an int or a ``(left, right, top, bottom)``
            4-tuple in the order expected by ``F.pad``.
        same: when true, ignore ``padding`` and derive the padding from the
            input size, kernel size and stride instead.
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super(MedianPool2d, self).__init__()
        self.k = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _quadruple(padding)  # (left, right, top, bottom)
        self.same = same

    def _padding(self, x):
        """Return the ``(left, right, top, bottom)`` padding to apply to ``x``."""
        if not self.same:
            return self.padding

        # Total padding per spatial axis: the kernel size minus the leftover
        # of the input extent modulo the stride (or minus the stride itself
        # when the extent divides evenly), never negative.
        totals = []
        for extent, kernel, step in zip(x.size()[2:], self.k, self.stride):
            leftover = extent % step
            totals.append(max(kernel - (leftover or step), 0))
        total_h, total_w = totals

        # Split each total in two; the odd unit goes to the right/bottom.
        left = total_w // 2
        top = total_h // 2
        return (left, total_w - left, top, total_h - top)

    def forward(self, x):
        """Return the per-window median of ``x``.

        Built from existing tensor ops (pad / unfold / median) so autograd
        keeps working; a dedicated C/CUDA kernel would likely be faster.
        """
        kernel_h, kernel_w = self.k
        step_h, step_w = self.stride
        padded = F.pad(x, self._padding(x), mode='reflect')
        # Slide windows along dim 2, then dim 3: the result has one
        # (kernel_h, kernel_w) patch per output position.
        windows = padded.unfold(2, kernel_h, step_h).unfold(3, kernel_w, step_w)
        # Flatten each patch to a vector and take its median.
        flat = windows.contiguous().view(windows.size()[:4] + (-1,))
        return flat.median(dim=-1)[0]
class GNNV1(nn.Module):
    """Graph message passing on CNN features using a per-neighbour MLP.

    Every spatial position of the encoder output is a graph node. Nodes are
    linked to their ``k`` nearest neighbours, where distance is measured
    between the median-pooled values of the original input. At each
    iteration the neighbours' features go through a stack of
    ``Linear + ReLU`` layers, are averaged, concatenated with the node's own
    state and projected back to ``channels`` features.

    Args:
        mlp_num: number of ``Linear + ReLU`` layers applied to neighbour
            features.
        k: number of nearest neighbours per node.
        gnn_iterations: number of message-passing iterations.
        device: stored on the instance for backward compatibility; index
            tensors are created on the device of the inputs.
        channels: number of feature channels of the encoder output
            (default 2048, the previously hard-coded width).
    """

    def __init__(self, mlp_num, k, gnn_iterations, device, channels=2048):
        super(GNNV1, self).__init__()
        self.k = k
        self.device = device
        self.gnn_iternum = gnn_iterations
        self.median_pool = MedianPool2d(
            kernel_size=32, stride=32, padding=0, same=False)
        self.g_rnn_layers = nn.ModuleList(
            [nn.Linear(channels, channels) for _ in range(mlp_num)])
        self.g_rnn_actfs = nn.ModuleList(
            [nn.ReLU(inplace=True) for _ in range(mlp_num)])
        self.q_rnn_layer = nn.Linear(2 * channels, channels)
        self.q_rnn_actf = nn.ReLU(inplace=True)
        self.output_conv = nn.Conv2d(2 * channels, channels, 3,
                                     stride=1, padding=1, bias=True)

    def forward(self, cnn_encoder_output, original_input):
        """Refine the encoder features with k-NN message passing.

        :param cnn_encoder_output: CNN encoder output, unpacked as
            ``(N, C, H, W)``.
        :param original_input: original image batch; after the 32x32 median
            pooling it is viewed as ``(N, 3, H*W)``, so it must have three
            channels and pool down to the encoder's ``H x W`` grid.
        :return: output of the final 3x3 convolution applied to the encoder
            features concatenated with the refined features.
        """
        N, C, H, W = cnn_encoder_output.size()
        K = self.k
        HW = H * W

        # The pooled colour values stand in for 3-D coordinates.
        proj_3d = self.median_pool(original_input)  # N 3 H W
        proj_3d = proj_3d.view(N, 3, HW).transpose(2, 1).contiguous()  # N HW 3

        knn = self.__get_knn_indices(proj_3d).view(N, HW * K)  # N HWK

        # A permuted tensor must be made contiguous before it can be viewed.
        h = cnn_encoder_output.permute(0, 2, 3, 1).contiguous()  # N H W C
        h = h.view(N, HW, C)  # N HW C

        # Row selector that pairs every neighbour index with its own sample,
        # so the gather is done for the whole batch at once.
        batch_idx = torch.arange(N, device=knn.device).unsqueeze(1)  # N 1

        for _ in range(self.gnn_iternum):
            # Fetch neighbour features; repeated indices are allowed.
            neighbor_f = h[batch_idx, knn].view(N, HW, K, C)  # N HW K C
            # The linear layers act on the last dimension only, so every
            # neighbour vector is transformed independently.
            for g_line, g_actf in zip(self.g_rnn_layers, self.g_rnn_actfs):
                neighbor_f = g_actf(g_line(neighbor_f))
            # Average over the activated neighbours to get the message.
            m = torch.mean(neighbor_f, dim=2)  # N HW C
            # Concatenate the current state with the message and project.
            concat = torch.cat((h, m), 2)  # N HW 2C
            h = self.q_rnn_actf(self.q_rnn_layer(concat))  # N HW C

        # Back to image layout, then fuse with the original CNN embedding.
        h = h.view(N, H, W, C).permute(0, 3, 1, 2).contiguous()  # N C H W
        output = self.output_conv(
            torch.cat((cnn_encoder_output, h), 1))  # input is N 2C H W
        return output

    def __get_knn_indices(self, batch_mat):
        """Return the indices of the ``k`` nearest points of every point.

        :param batch_mat: tensor of shape ``(N, P, D)`` holding ``P`` points
            per sample.
        :return: integer tensor of shape ``(N, P, k)`` on the device of
            ``batch_mat``; entries index the ``P`` dimension.
        """
        # Indices carry no gradient, so skip building the autograd graph.
        with torch.no_grad():
            # |x - y|^2 = |x|^2 - 2 x.y + |y|^2
            gram = torch.bmm(batch_mat, batch_mat.transpose(1, 2))  # N P P
            sq_norm = torch.diagonal(gram, dim1=1, dim2=2)  # N P
            sq_dist = sq_norm.unsqueeze(2) + sq_norm.unsqueeze(1) - 2 * gram
            # Rounding can push entries slightly below zero. Ranking squared
            # distances gives the same neighbours as ranking distances and
            # avoids the NaN that sqrt would produce for those entries.
            sq_dist = sq_dist.clamp(min=0)
            return torch.topk(sq_dist, k=self.k, dim=-1, largest=False)[1]
class GNNV2(nn.Module):
    """Graph message passing where neighbour features go through 1x1 convs.

    Same overall scheme as a k-NN message-passing GNN on CNN features, but
    the ``k`` neighbour feature vectors of each node are stacked along the
    channel axis and transformed jointly by ``Conv2d(1x1) + BatchNorm +
    ReLU`` stages before being averaged.

    Args:
        mlp_num: number of ``Conv + BatchNorm + ReLU`` stages. Each stage
            owns its own parameters, and each convolution holds
            ``(channels * k) ** 2`` weights.
        k: number of nearest neighbours per node.
        gnn_iterations: number of message-passing iterations.
        device: stored on the instance for backward compatibility; index
            tensors are created on the device of the inputs.
        channels: number of feature channels of the encoder output
            (default 2048, the previously hard-coded width).
    """

    def __init__(self, mlp_num, k, gnn_iterations, device, channels=2048):
        super(GNNV2, self).__init__()
        self.k = k
        self.device = device
        self.gnn_iternum = gnn_iterations
        self.median_pool = MedianPool2d(
            kernel_size=32, stride=32, padding=0, same=False)

        # Build fresh modules for every stage. Multiplying a list of modules
        # repeats the same objects, which would tie the weights of all stages.
        g_rnn_conv_list = []
        for _ in range(mlp_num):
            g_rnn_conv_list.extend([
                nn.Conv2d(channels * self.k, channels * self.k, 1),
                nn.BatchNorm2d(channels * self.k),
                nn.ReLU(inplace=True),
            ])
        self.g_rnn_conv = nn.Sequential(*g_rnn_conv_list)

        self.q_rnn_layer = nn.Linear(2 * channels, channels)
        self.q_rnn_actf = nn.ReLU(inplace=True)
        self.output_conv = nn.Conv2d(2 * channels, channels, 3,
                                     stride=1, padding=1, bias=True)

    def forward(self, cnn_encoder_output, original_input):
        """Refine the encoder features with k-NN message passing.

        :param cnn_encoder_output: CNN encoder output, unpacked as
            ``(N, C, H, W)``.
        :param original_input: original image batch; after the 32x32 median
            pooling it is viewed as ``(N, 3, H*W)``, so it must have three
            channels and pool down to the encoder's ``H x W`` grid.
        :return: output of the final 3x3 convolution applied to the encoder
            features concatenated with the refined features.
        """
        N, C, H, W = cnn_encoder_output.size()
        K = self.k
        HW = H * W

        # The pooled colour values stand in for 3-D coordinates.
        proj_3d = self.median_pool(original_input)  # N 3 H W
        proj_3d = proj_3d.view(N, 3, HW).transpose(2, 1).contiguous()  # N HW 3

        knn = self.__get_knn_indices(proj_3d)  # N HW K, values in [0, HW)
        # Features are gathered from a tensor flattened over the batch, so
        # shift each sample's indices to that sample's block of HW rows.
        sample_offset = torch.arange(N, device=knn.device).view(N, 1, 1) * HW
        knn = (knn + sample_offset).view(N * HW * K)  # NHWK

        # A permuted tensor must be made contiguous before it can be viewed.
        h = cnn_encoder_output.permute(0, 2, 3, 1).contiguous()  # N H W C

        for _ in range(self.gnn_iternum):
            h = h.view(N * HW, C)  # NHW C
            # Fetch neighbour features; repeated indices are allowed.
            neighbor_f = torch.index_select(h, 0, knn).view(N, H, W, K * C)
            neighbor_f = neighbor_f.permute(0, 3, 1, 2)  # N KC H W
            neighbor_f = self.g_rnn_conv(neighbor_f)
            neighbor_f = neighbor_f.permute(0, 2, 3, 1).contiguous()  # N H W KC
            neighbor_f = neighbor_f.view(N, HW, K, C)
            # Average over the neighbours to get the message.
            m = torch.mean(neighbor_f, dim=2)  # N HW C
            h = h.view(N, HW, C)
            # Concatenate the current state with the message and project.
            concat = torch.cat((h, m), 2)  # N HW 2C
            h = self.q_rnn_actf(self.q_rnn_layer(concat))  # N HW C

        # Back to image layout, then fuse with the original CNN embedding.
        h = h.view(N, H, W, C).permute(0, 3, 1, 2).contiguous()  # N C H W
        output = self.output_conv(
            torch.cat((cnn_encoder_output, h), 1))  # input is N 2C H W
        return output

    # adapted from
    # https://discuss.pytorch.org/t/build-your-own-loss-function-in-pytorch/235/6
    # (x - y)^2 = x^2 - 2*x*y + y^2
    def __get_knn_indices(self, batch_mat):
        """Return the indices of the ``k`` nearest points of every point.

        :param batch_mat: tensor of shape ``(N, P, D)`` holding ``P`` points
            per sample.
        :return: integer tensor of shape ``(N, P, k)`` on the device of
            ``batch_mat``; entries index the ``P`` dimension.
        """
        # Indices carry no gradient, so skip building the autograd graph.
        with torch.no_grad():
            gram = torch.bmm(batch_mat, batch_mat.transpose(1, 2))  # N P P
            sq_norm = torch.diagonal(gram, dim1=1, dim2=2)  # N P
            sq_dist = sq_norm.unsqueeze(2) + sq_norm.unsqueeze(1) - 2 * gram
            # Rounding can push entries slightly below zero. Ranking squared
            # distances gives the same neighbours as ranking distances and
            # avoids the NaN that sqrt would produce for those entries.
            sq_dist = sq_dist.clamp(min=0)
            return torch.topk(sq_dist, k=self.k, dim=-1, largest=False)[1]
class GNNV3(nn.Module):
    """Fully convolutional variant of the k-NN message-passing block.

    The ``k`` neighbour feature vectors of each node are stacked along the
    channel axis and transformed jointly by ``Conv2d(1x1) + BatchNorm +
    ReLU`` stages, averaged into a message, and the state update is done by
    another ``Conv2d(1x1) + BatchNorm + ReLU`` stage.

    Args:
        mlp_num: number of ``Conv + BatchNorm + ReLU`` stages applied to the
            neighbour features. Each stage owns its own parameters, and each
            convolution holds ``(channels * k) ** 2`` weights.
        k: number of nearest neighbours per node.
        gnn_iterations: number of message-passing iterations.
        device: stored on the instance for backward compatibility; index
            tensors are created on the device of the inputs.
        channels: number of feature channels of the encoder output
            (default 2048, the previously hard-coded width).
    """

    def __init__(self, mlp_num, k, gnn_iterations, device, channels=2048):
        super(GNNV3, self).__init__()
        self.k = k
        self.device = device
        self.gnn_iternum = gnn_iterations
        self.median_pool = MedianPool2d(
            kernel_size=32, stride=32, padding=0, same=False)

        # Build fresh modules for every stage. Multiplying a list of modules
        # repeats the same objects, which would tie the weights of all stages.
        g_rnn_conv_list = []
        for _ in range(mlp_num):
            g_rnn_conv_list.extend([
                nn.Conv2d(channels * self.k, channels * self.k, 1),
                nn.BatchNorm2d(channels * self.k),
                nn.ReLU(inplace=True),
            ])
        self.g_rnn_conv = nn.Sequential(*g_rnn_conv_list)

        self.q_rnn_conv = nn.Sequential(
            nn.Conv2d(2 * channels, channels, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True)
        )
        self.output_conv = nn.Conv2d(2 * channels, channels, 3,
                                     stride=1, padding=1, bias=True)

    def forward(self, cnn_encoder_output, original_input):
        """Refine the encoder features with k-NN message passing.

        :param cnn_encoder_output: CNN encoder output, unpacked as
            ``(N, C, H, W)``.
        :param original_input: original image batch; after the 32x32 median
            pooling it is viewed as ``(N, 3, H*W)``, so it must have three
            channels and pool down to the encoder's ``H x W`` grid.
        :return: output of the final 3x3 convolution applied to the encoder
            features concatenated with the refined features.
        """
        N, C, H, W = cnn_encoder_output.size()
        K = self.k
        HW = H * W

        # The pooled colour values stand in for 3-D coordinates.
        proj_3d = self.median_pool(original_input)  # N 3 H W
        proj_3d = proj_3d.view(N, 3, HW).transpose(2, 1).contiguous()  # N HW 3

        knn = self.__get_knn_indices(proj_3d)  # N HW K, values in [0, HW)
        # Features are gathered from a tensor flattened over the batch, so
        # shift each sample's indices to that sample's block of HW rows.
        sample_offset = torch.arange(N, device=knn.device).view(N, 1, 1) * HW
        knn = (knn + sample_offset).view(N * HW * K)  # NHWK

        # A permuted tensor must be made contiguous before it can be viewed.
        h = cnn_encoder_output.permute(0, 2, 3, 1).contiguous()  # N H W C

        for _ in range(self.gnn_iternum):
            h = h.view(N * HW, C)  # NHW C
            # Fetch neighbour features; repeated indices are allowed.
            neighbor_f = torch.index_select(h, 0, knn).view(N, H, W, K * C)
            neighbor_f = neighbor_f.permute(0, 3, 1, 2)  # N KC H W
            neighbor_f = self.g_rnn_conv(neighbor_f)
            neighbor_f = neighbor_f.permute(0, 2, 3, 1).contiguous()  # N H W KC
            neighbor_f = neighbor_f.view(N, HW, K, C)
            # Average over the neighbours to get the message.
            m = torch.mean(neighbor_f, dim=2)  # N HW C
            h = h.view(N, HW, C)
            # Concatenate the current state with the message, move channels
            # first and update the state with the 1x1 convolution stage.
            concat = torch.cat((h, m), 2).view(N, H, W, 2 * C)  # N H W 2C
            concat = concat.permute(0, 3, 1, 2)  # N 2C H W
            h = self.q_rnn_conv(concat)  # N C H W
            h = h.permute(0, 2, 3, 1).contiguous()  # N H W C

        # Back to image layout, then fuse with the original CNN embedding.
        h = h.permute(0, 3, 1, 2).contiguous()  # N C H W
        output = self.output_conv(
            torch.cat((cnn_encoder_output, h), 1))  # input is N 2C H W
        return output

    # adapted from
    # https://discuss.pytorch.org/t/build-your-own-loss-function-in-pytorch/235/6
    # (x - y)^2 = x^2 - 2*x*y + y^2
    def __get_knn_indices(self, batch_mat):
        """Return the indices of the ``k`` nearest points of every point.

        :param batch_mat: tensor of shape ``(N, P, D)`` holding ``P`` points
            per sample.
        :return: integer tensor of shape ``(N, P, k)`` on the device of
            ``batch_mat``; entries index the ``P`` dimension.
        """
        # Indices carry no gradient, so skip building the autograd graph.
        with torch.no_grad():
            gram = torch.bmm(batch_mat, batch_mat.transpose(1, 2))  # N P P
            sq_norm = torch.diagonal(gram, dim1=1, dim2=2)  # N P
            sq_dist = sq_norm.unsqueeze(2) + sq_norm.unsqueeze(1) - 2 * gram
            # Rounding can push entries slightly below zero. Ranking squared
            # distances gives the same neighbours as ranking distances and
            # avoids the NaN that sqrt would produce for those entries.
            sq_dist = sq_dist.clamp(min=0)
            return torch.topk(sq_dist, k=self.k, dim=-1, largest=False)[1]
if __name__ == '__main__':
    # Smoke test: run one forward pass on random data and report its
    # output size and wall-clock duration.
    import time

    # Fall back to the CPU so the script also runs on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    gnn = GNNV2(3, k=12, gnn_iterations=3, device=device).to(device)
    cnn_encoder_output = torch.randint(0, 255, size=(5, 2048, 7, 7),
                                       dtype=torch.float32).to(device)
    original_input = torch.randint(0, 255, size=(5, 3, 224, 224),
                                   dtype=torch.float32).to(device)

    # CUDA kernels run asynchronously; synchronise around the timed region
    # so the measurement covers the actual computation.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start = time.time()
    output = gnn(cnn_encoder_output,
                 original_input)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elapsed = time.time() - start

    print(output.size())
    print(elapsed)
    # Values recorded by the original author (presumably seconds per
    # forward pass for each variant -- confirm before relying on them):
    # v1 0.054094552993774414
    # v2 0.01854419708251953
    # v3 0.02263331413269043
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment