Co-occurrent Features in Semantic Segmentation
###########################################################################
# Created by: Hang Zhang
# Email: [email protected]
# Copyright (c) 2018
###########################################################################
from __future__ import division
import os
import numpy as np

import torch
import torch.nn as nn
from torch.nn.functional import interpolate

from .base import BaseNet
from ..nn import ACFModule, ConcurrentModule, SyncBatchNorm
from .fcn import FCNHead
from .encnet import EncModule

__all__ = ['ATTEN', 'get_atten']


class ATTEN(BaseNet):
    def __init__(self, nclass, backbone, nheads=8, nmixs=1, with_global=True,
                 with_enc=True, with_lateral=False, aux=True, se_loss=False,
                 norm_layer=SyncBatchNorm, **kwargs):
        super(ATTEN, self).__init__(nclass, backbone, aux, se_loss,
                                    norm_layer=norm_layer, **kwargs)
        in_channels = 4096 if self.backbone.startswith('wideresnet') else 2048
        self.head = ATTENHead(in_channels, nclass, norm_layer, self._up_kwargs,
                              nheads=nheads, nmixs=nmixs, with_global=with_global,
                              with_enc=with_enc, se_loss=se_loss,
                              lateral=with_lateral)
        if aux:
            self.auxlayer = FCNHead(1024, nclass, norm_layer)

    def forward(self, x):
        imsize = x.size()[2:]
        #_, _, c3, c4 = self.base_forward(x)
        #x = list(self.head(c4))
        features = self.base_forward(x)
        x = list(self.head(*features))
        x[0] = interpolate(x[0], imsize, **self._up_kwargs)
        if self.aux:
            #auxout = self.auxlayer(c3)
            auxout = self.auxlayer(features[2])
            auxout = interpolate(auxout, imsize, **self._up_kwargs)
            x.append(auxout)
        return tuple(x)

    def demo(self, x):
        imsize = x.size()[2:]
        features = self.base_forward(x)
        return self.head.demo(*features)


class GlobalPooling(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer, up_kwargs):
        super(GlobalPooling, self).__init__()
        self._up_kwargs = up_kwargs
        self.gap = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                 nn.Conv2d(in_channels, out_channels, 1, bias=False),
                                 norm_layer(out_channels),
                                 nn.ReLU(True))

    def forward(self, x):
        _, _, h, w = x.size()
        pool = self.gap(x)
        return interpolate(pool, (h,w), **self._up_kwargs)


class ATTENHead(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer, up_kwargs,
                 nheads, nmixs, with_global,
                 with_enc, se_loss, lateral):
        super(ATTENHead, self).__init__()
        self.with_enc = with_enc
        self.se_loss = se_loss
        self._up_kwargs = up_kwargs
        inter_channels = in_channels // 4
        self.lateral = lateral
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels),
            nn.ReLU())
        if lateral:
            self.connect = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(512, 512, kernel_size=1, bias=False),
                    norm_layer(512),
                    nn.ReLU(inplace=True)),
                nn.Sequential(
                    nn.Conv2d(1024, 512, kernel_size=1, bias=False),
                    norm_layer(512),
                    nn.ReLU(inplace=True)),
            ])
            self.fusion = nn.Sequential(
                nn.Conv2d(3*512, 512, kernel_size=3, padding=1, bias=False),
                norm_layer(512),
                nn.ReLU(inplace=True))
        extended_channels = 0
        self.atten = ACFModule(nheads, nmixs, inter_channels, inter_channels//nheads*nmixs,
                               inter_channels//nheads, norm_layer)
        if with_global:
            extended_channels = inter_channels
            self.atten_layers = ConcurrentModule([
                GlobalPooling(inter_channels, extended_channels, norm_layer, self._up_kwargs),
                self.atten,
                #nn.Sequential(*atten),
            ])
        else:
            self.atten_layers = nn.Sequential(self.atten)
        if with_enc:
            self.encmodule = EncModule(inter_channels+extended_channels, out_channels, ncodes=32,
                                       se_loss=se_loss, norm_layer=norm_layer)
        self.conv6 = nn.Sequential(nn.Dropout2d(0.1, False),
                                   nn.Conv2d(inter_channels+extended_channels, out_channels, 1))

    def forward(self, *inputs):
        feat = self.conv5(inputs[-1])
        if self.lateral:
            c2 = self.connect[0](inputs[1])
            c3 = self.connect[1](inputs[2])
            feat = self.fusion(torch.cat([feat, c2, c3], 1))
        feat = self.atten_layers(feat)
        if self.with_enc:
            outs = list(self.encmodule(feat))
        else:
            outs = [feat]
        outs[0] = self.conv6(outs[0])
        return tuple(outs)

    def demo(self, *inputs):
        feat = self.conv5(inputs[-1])
        if self.lateral:
            c2 = self.connect[0](inputs[1])
            c3 = self.connect[1](inputs[2])
            feat = self.fusion(torch.cat([feat, c2, c3], 1))
        attn = self.atten.demo(feat)
        return attn


def get_atten(dataset='pascal_voc', backbone='resnet50', pretrained=False,
              root='~/.encoding/models', **kwargs):
r"""ATTEN model from the paper `"Fully Convolutional Network for semantic segmentation" | |
<https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_atten.pdf>`_ | |
Parameters | |
---------- | |
dataset : str, default pascal_voc | |
The dataset that model pretrained on. (pascal_voc, ade20k) | |
pretrained : bool, default False | |
Whether to load the pretrained weights for model. | |
pooling_mode : str, default 'avg' | |
Using 'max' pool or 'avg' pool in the Attention module. | |
root : str, default '~/.encoding/models' | |
Location for keeping the model parameters. | |
Examples | |
-------- | |
>>> model = get_atten(dataset='pascal_voc', backbone='resnet50', pretrained=False) | |
>>> print(model) | |
""" | |
    # infer number of classes
    from ..datasets import datasets, acronyms
    model = ATTEN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        model.load_state_dict(torch.load(
            get_model_file('atten_%s_%s'%(backbone, acronyms[dataset]), root=root)))
    return model
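Below is a minimal usage sketch for the factory above; it is a smoke test under stated assumptions, not a prescribed workflow: it presumes the surrounding package (BaseNet, FCNHead, EncModule, the datasets module) is installed, passes a plain torch.nn.BatchNorm2d to sidestep the synchronized-BatchNorm default, and uses an arbitrary 480x480 crop.

import torch
import torch.nn as nn

# Hedged smoke test: build the network without pretrained segmentation weights
# and run a dummy forward pass. The norm_layer and crop size are assumptions
# for a single-device check; training uses SyncBatchNorm and dataset crops.
model = get_atten(dataset='pascal_voc', backbone='resnet50', pretrained=False,
                  norm_layer=nn.BatchNorm2d)
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 480, 480)
    outputs = model(dummy)
# outputs is a tuple: outputs[0] holds per-pixel class logits upsampled to the
# input size; further entries are the SE-loss and auxiliary outputs when those
# branches are enabled (aux=True by default).
print([tuple(o.shape) for o in outputs])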
###########################################################################
# Created by: Hang Zhang
# Email: [email protected]
# Copyright (c) 2018
###########################################################################
import numpy as np

import torch
import torch.nn as nn
from torch.nn import functional as F

from .syncbn import SyncBatchNorm

__all__ = ['ACFModule', 'MixtureOfSoftMaxACF']


class ACFModule(nn.Module):
    """ Multi-Head Attention module """
    def __init__(self, n_head, n_mix, d_model, d_k, d_v, norm_layer=SyncBatchNorm,
                 kq_transform='conv', value_transform='conv',
                 pooling=True, concat=False, dropout=0.1):
        super(ACFModule, self).__init__()
        self.n_head = n_head
        self.n_mix = n_mix
        self.d_k = d_k
        self.d_v = d_v
        self.pooling = pooling
        self.concat = concat
        if self.pooling:
            self.pool = nn.AvgPool2d(3, 2, 1, count_include_pad=False)
        if kq_transform == 'conv':
            self.conv_qs = nn.Conv2d(d_model, n_head*d_k, 1)
            nn.init.normal_(self.conv_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        elif kq_transform == 'ffn':
            self.conv_qs = nn.Sequential(
                nn.Conv2d(d_model, n_head*d_k, 3, padding=1, bias=False),
                norm_layer(n_head*d_k),
                nn.ReLU(True),
                nn.Conv2d(n_head*d_k, n_head*d_k, 1),
            )
            nn.init.normal_(self.conv_qs[-1].weight, mean=0, std=np.sqrt(1.0 / d_k))
        elif kq_transform == 'dffn':
            self.conv_qs = nn.Sequential(
                nn.Conv2d(d_model, n_head*d_k, 3, padding=4, dilation=4, bias=False),
                norm_layer(n_head*d_k),
                nn.ReLU(True),
                nn.Conv2d(n_head*d_k, n_head*d_k, 1),
            )
            nn.init.normal_(self.conv_qs[-1].weight, mean=0, std=np.sqrt(1.0 / d_k))
        else:
            raise NotImplementedError
        #self.conv_ks = nn.Conv2d(d_model, n_head*d_k, 1)
        self.conv_ks = self.conv_qs
        if value_transform == 'conv':
            self.conv_vs = nn.Conv2d(d_model, n_head*d_v, 1)
        else:
            raise NotImplementedError
        #nn.init.normal_(self.conv_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.conv_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
        self.attention = MixtureOfSoftMaxACF(n_mix=n_mix, d_k=d_k)
        self.conv = nn.Conv2d(n_head*d_v, d_model, 1, bias=False)
        self.norm_layer = norm_layer(d_model)

    def forward(self, x):
        residual = x
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        b_, c_, h_, w_ = x.size()
        if self.pooling:
            qt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            kt = self.conv_ks(self.pool(x)).view(b_*n_head, d_k, h_*w_//4)
            vt = self.conv_vs(self.pool(x)).view(b_*n_head, d_v, h_*w_//4)
        else:
            kt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            qt = kt
            vt = self.conv_vs(x).view(b_*n_head, d_v, h_*w_)
        output, attn = self.attention(qt, kt, vt)
        output = output.transpose(1, 2).contiguous().view(b_, n_head*d_v, h_, w_)
        output = self.conv(output)
        if self.concat:
            output = torch.cat((self.norm_layer(output), residual), 1)
        else:
            output = self.norm_layer(output) + residual
        return output

    def demo(self, x):
        residual = x
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        b_, c_, h_, w_ = x.size()
        if self.pooling:
            qt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            kt = self.conv_ks(self.pool(x)).view(b_*n_head, d_k, h_*w_//4)
            vt = self.conv_vs(self.pool(x)).view(b_*n_head, d_v, h_*w_//4)
        else:
            kt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
            qt = kt
            vt = self.conv_vs(x).view(b_*n_head, d_v, h_*w_)
        _, attn = self.attention(qt, kt, vt)
        attn = attn.view(b_, n_head, h_*w_, -1)
        return attn

    def extra_repr(self):
        return 'n_head={}, n_mix={}, d_k={}, pooling={}' \
            .format(self.n_head, self.n_mix, self.d_k, self.pooling)
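# A reading of the computation in MixtureOfSoftMaxACF (below), for reference:
#
#   split Q and K channel-wise into m = n_mix chunks of d_k // m channels,
#   A_i = softmax(Q_i^T K_i / sqrt(d_k))          one attention map per mixture
#   pi  = softmax(W @ mean_over_positions(Q))     mixture weights (only if m > 1)
#   A   = sum_i pi_i * A_i,   output = A V
#
# With n_mix == 1 (the value ATTEN passes by default) this reduces to standard
# scaled dot-product attention with dropout on the attention map.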
class MixtureOfSoftMaxACF(nn.Module):
    """Mixture of SoftMax"""
    def __init__(self, n_mix, d_k, attn_dropout=0.1):
        super(MixtureOfSoftMaxACF, self).__init__()
        self.temperature = np.power(d_k, 0.5)
        self.n_mix = n_mix
        self.att_drop = attn_dropout
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax1 = nn.Softmax(dim=1)
        self.softmax2 = nn.Softmax(dim=2)
        self.d_k = d_k
        if n_mix > 1:
            self.weight = nn.Parameter(torch.Tensor(n_mix, d_k))
            std = np.power(n_mix, -0.5)
            self.weight.data.uniform_(-std, std)

    def forward(self, qt, kt, vt):
        B, d_k, N = qt.size()
        m = self.n_mix
        assert d_k == self.d_k
        d = d_k // m
        if m > 1:
            # \bar{v} \in R^{B, d_k, 1}
            bar_qt = torch.mean(qt, 2, True)
            # pi \in R^{B, m, 1}
            pi = self.softmax1(torch.matmul(self.weight, bar_qt)).view(B*m, 1, 1)
        # reshape for n_mix
        q = qt.view(B*m, d, N).transpose(1, 2)
        N2 = kt.size(2)
        kt = kt.view(B*m, d, N2)
        v = vt.transpose(1, 2)
        # attn \in R^{Bm, N, N2}
        attn = torch.bmm(q, kt)
        attn = attn / self.temperature
        attn = self.softmax2(attn)
        attn = self.dropout(attn)
        if m > 1:
            # attn \in R^{Bm, N, N2} => R^{B, N, N2}
            attn = (attn * pi).view(B, m, N, N2).sum(1)
        output = torch.bmm(attn, v)
        return output, attn
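A short smoke test for ACFModule in isolation follows; it is a hedged sketch that assumes the classes above are importable, substitutes nn.BatchNorm2d for the SyncBatchNorm default so it runs on a single device, and mirrors the channel split ATTENHead uses (d_k = d_v = d_model // n_head with a single mixture).

import torch
import torch.nn as nn

# d_model=512 mimics ATTENHead's inter_channels for a ResNet-50 backbone;
# the 2x512x32x32 feature map is an arbitrary choice for the shape check.
acf = ACFModule(n_head=8, n_mix=1, d_model=512, d_k=64, d_v=64,
                norm_layer=nn.BatchNorm2d)
acf.eval()
with torch.no_grad():
    feat = torch.randn(2, 512, 32, 32)      # (batch, channels, H, W)
    out = acf(feat)
    attn = acf.demo(feat)
print(tuple(out.shape))     # (2, 512, 32, 32): same shape as the input when concat=False
print(tuple(attn.shape))    # (2, 8, 1024, 256): per-head attention over the pooled keys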