Last active
January 25, 2021 14:15
-
-
Save phizaz/9eb7f39f3ef2fd1b9270e5d7d0e66037 to your computer and use it in GitHub Desktop.
li2018
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from segmentation_models_pytorch.encoders import get_encoder
from torch import nn

from .mil import MILPool
def make_net_li2018(
    backbone,
    n_out,
    n_in=1,
    n_dec_ch=512,
    out_size=20,
    pooling='milpool',
    min_val=0.98,
    pretrain='imagenet',
    **kwargs,
):
    """Build a patch-grid classifier class after Li et al. (2018).

    The returned ``nn.Module`` subclass does four things:

    1. Runs the ``backbone`` encoder (via ``get_encoder``).
    2. Resizes the encoder's last feature map to an ``out_size x out_size`` grid.
    3. Predicts ``n_out`` per-patch scores with a small conv head.
    4. Pools the grid into one score per output channel.

    Args:
        backbone: encoder name passed to ``get_encoder``.
        n_out: number of output channels (classes).
        n_in: number of input image channels.
        n_dec_ch: width of the 3x3 conv in the head.
        out_size: side length of the patch grid.
        pooling: one of ``'maxpool'``, ``'avgpool'`` or ``'milpool'``.
        min_val: ``MILPool`` cap on each ``1 - p`` factor; only used by
            ``'milpool'`` (and only added to the name when not ``None``).
        pretrain: passed as ``weights=`` to ``get_encoder``; falsy values are
            left out of the model name.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        A class (not an instance), decorated with ``rename(name)``.
        ``rename`` is defined elsewhere; presumably it sets the class name
        (TODO confirm). Instances take no constructor arguments.
        ``forward(x)`` returns ``{'pred': (N, n_out), 'seg': (N, n_out,
        out_size, out_size)}``.

    Raises:
        KeyError: on instantiation, if ``pooling`` is not a known option.
            It is raised before the encoder is built.
    """
    name = f'li2018,out{out_size}-{backbone}-{pooling}-out{n_out}'
    # NOTE(review): no separator before 'in'. Left as-is because the exact
    # name may identify existing runs/checkpoints.
    if n_in != 1:
        name += f'in{n_in}'
    if pretrain:
        name += f'-pretrain{pretrain}'
    if pooling == 'milpool' and min_val is not None:
        name += f',min{min_val}'

    # Constructors rather than instances: only the selected pooling is built.
    pooling_factories = {
        'maxpool': lambda: nn.AdaptiveMaxPool2d(1),
        'avgpool': lambda: nn.AdaptiveAvgPool2d(1),
        'milpool': lambda: MILPool(
            min_val=min_val, apply_sigmoid=True, ret_logit=True
        ),
    }

    @rename(name)
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # Resolve the pooling option first, so a typo fails fast instead of
            # after the (possibly pretrained) encoder has been constructed.
            try:
                make_pool = pooling_factories[pooling]
            except KeyError:
                raise KeyError(
                    f'unknown pooling {pooling!r}; '
                    f'expected one of {sorted(pooling_factories)}'
                ) from None

            self.net = get_encoder(
                name=backbone,
                in_channels=n_in,
                weights=pretrain,
            )
            self.out = nn.Sequential(
                # Exactly what the deprecated nn.UpsamplingBilinear2d(size) did.
                nn.Upsample(
                    size=(out_size, out_size),
                    mode='bilinear',
                    align_corners=True,
                ),
                nn.Conv2d(self.net.out_channels[-1], n_dec_ch, 3, padding=1),
                nn.BatchNorm2d(n_dec_ch),
                nn.ReLU(),
                nn.Conv2d(n_dec_ch, n_out, 1, bias=True),
            )
            self.pool = make_pool()

        def forward(self, x):
            # Only the last feature map returned by the encoder is used.
            x = self.net(x)[-1]
            # Per-patch scores on the out_size x out_size grid, cast to float32
            # (presumably so pooling runs in full precision under autocast).
            seg = self.out(x).float()
            # Pooling yields (N, n_out, 1, 1); flatten to (N, n_out).
            pred = torch.flatten(self.pool(seg), 1)
            return {
                'pred': pred,
                'seg': seg,
            }

    return Net
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def mil_output(p, min_val):
    """Noisy-OR pooling of patch probabilities (multi-instance learning).

    Each probability is squashed into ``[0, 1 - min_val]``, so every
    ``1 - p`` factor stays in ``[min_val, 1]``. The result is the probability
    that at least one patch is positive::

        pred = 1 - prod_k (1 - (1 - min_val) * p_k)

    The product is evaluated as
    ``-expm1(sum_k log1p(-(1 - min_val) * p_k))``. This is algebraically
    identical to the direct form, with two numerical advantages:

    * the product cannot underflow;
    * tiny ``p_k`` no longer round ``1 - p_k`` to exactly 1, which collapsed
      ``pred`` to 0.

    Args:
        p: probabilities of shape ``(N, C, *spatial)``, presumably in ``[0, 1]``.
        min_val: lower cap on each ``1 - p`` factor, expected in ``[0, 1]``.
            ``None`` disables the cap (same as ``0``).

    Returns:
        float32 tensor of shape ``(N, C, 1, ..., 1)``, with one singleton per
        spatial dim of ``p`` (``(N, C, 1, 1)`` for 4-D input).
    """
    n_spatial = p.dim() - 2
    scale = 1.0 if min_val is None else 1.0 - min_val
    # (N, C, K): all patches of one (sample, channel) in a row, in float32.
    p = p.float().flatten(2)
    # log prod_k (1 - scale * p_k), accumulated as a sum of log1p terms.
    log_none_positive = torch.log1p(-scale * p).sum(dim=-1)
    pred = -torch.expm1(log_none_positive)
    return pred.reshape(pred.shape + (1,) * n_spatial)
class MILPool(nn.Module):
    """Multi-instance (noisy-OR) pooling over spatial patches.

    A (sample, class) is positive when at least one patch is positive::

        pred = 1 - prod_k (1 - (1 - min_val) * p_k)

    Each patch probability ``p_k`` is squashed into ``[0, 1 - min_val]`` first,
    so every ``1 - p`` factor stays in ``[min_val, 1]``.

    Found in:
    Li, Zhe, Chong Wang, Mei Han, Yuan Xue, Wei Wei, Li-Jia Li, and Li Fei-Fei. 2018.
    “Thoracic Disease Identification and Localization with Limited Supervision.”
    In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 8290–99.

    Args:
        min_val: lower cap on each ``1 - p`` factor, expected in ``[0, 1]``.
            ``None`` disables the cap (same as ``0``); no value is chosen
            automatically.
        apply_sigmoid: apply ``torch.sigmoid`` to the input first, i.e. treat
            the input as per-patch logits instead of probabilities.
        ret_logit: return the pooled value as a logit rather than a
            probability, keeping the same output interface as max/avg pooling
            of logits.

    Shape:
        input ``(N, C, *spatial)``; output float32 ``(N, C, 1, ..., 1)``
        (``(N, C, 1, 1)`` for 4-D input).
    """

    def __init__(self, min_val=0.98, apply_sigmoid=True, ret_logit=False):
        super().__init__()
        self.min_val = min_val
        self.apply_sigmoid = apply_sigmoid
        self.ret_logit = ret_logit

    def forward(self, x):
        """Pool ``x`` of shape ``(N, C, *spatial)`` down to ``(N, C, 1, ..., 1)``."""
        p = torch.sigmoid(x) if self.apply_sigmoid else x
        scale = 1.0 if self.min_val is None else 1.0 - self.min_val
        # (N, C, K) in float32: all patches of one (sample, class) in a row.
        p = p.float().flatten(2)
        # S = log prod_k (1 - scale * p_k), summed in log space. This avoids
        # product underflow, and keeps tiny p_k from being rounded away by 1 - p_k.
        log_none_positive = torch.log1p(-scale * p).sum(dim=-1)
        if self.ret_logit:
            # With pred = -expm1(S), we have 1 - pred = exp(S), so:
            #   logit(pred) = log(pred) - log(1 - pred) = log(-expm1(S)) - S.
            # This stays finite where log(pred / (1 - pred)) gave +/-inf
            # because pred had rounded to exactly 0 or 1.
            out = torch.log(-torch.expm1(log_none_positive)) - log_none_positive
        else:
            out = -torch.expm1(log_none_positive)
        return out.reshape(out.shape + (1,) * (x.dim() - 2))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment