Created August 12, 2022 11:03
# Modified by Shoufa Chen,
import math
import random
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from fvcore.nn import sigmoid_focal_loss_jit

from slowfast.models.losses import focal_loss_wo_logits_jit
from detectron2.modeling.poolers import ROIPooler
from detectron2.structures import Boxes
from slowfast.datasets.cv2_transform import clip_boxes_tensor

_DEFAULT_SCALE_CLAMP = math.log(100000.0 / 16)
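# Note: exp(_DEFAULT_SCALE_CLAMP) = 100000 / 16 = 6250, i.e. a clamped delta
# can scale a box side by at most 6250x; see apply_deltas below.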
class ResNetRoIHead(nn.Module):
    """
    ResNe(X)t RoI head.
    """

    def __init__(
        self,
        cfg,
        dim_in,
        num_classes,
        pool_size,
        resolution,
        scale_factor,
        dropout_rate=0.0,
        act_func="softmax",
        aligned=True,
        dim_before_proj=2048,
        use_fpn=False,
    ):
        """
        The `__init__` method of any subclass should also contain these
        arguments.
        ResNetRoIHead takes p pathways as input, where p >= 1.
        Args:
            dim_in (list): the list of channel dimensions of the p inputs to
                the ResNetHead.
            num_classes (int): the channel dimension of the output.
            pool_size (list): the list of kernel sizes of the p spatiotemporal
                poolings: temporal pool kernel size, spatial pool kernel size,
                spatial pool kernel size, in that order.
            resolution (list): the list of spatial output sizes of the
                ROIAlign.
            scale_factor (list): the list of ratios by which to scale the
                input boxes.
            dropout_rate (float): dropout rate. If 0.0, no dropout is
                performed.
            act_func (string): activation function to use. 'softmax': applies
                softmax on the output. 'sigmoid': applies sigmoid on the
                output.
            aligned (bool): if False, use the legacy implementation. If True,
                align the results more precisely.
        Note:
            Given a continuous coordinate c, its two neighboring pixel
            indices (in our pixel model) are computed by floor(c - 0.5) and
            ceil(c - 0.5). For example, c=1.3 has pixel neighbors with
            discrete indices [0] and [1] (which are sampled from the
            underlying signal at continuous coordinates 0.5 and 1.5). But the
            original roi_align (aligned=False) does not subtract the 0.5 when
            computing neighboring pixel indices, so it uses pixels with a
            slightly incorrect alignment (relative to our pixel model) when
            performing bilinear interpolation.
            With `aligned=True`, we first appropriately scale the ROI and
            then shift it by -0.5 prior to calling roi_align. This produces
            the correct neighbors; it makes negligible differences to the
            model's performance if ROIAlign is used together with conv
            layers.
        """
        super(ResNetRoIHead, self).__init__()
        assert (
            len({len(pool_size), len(dim_in)}) == 1
        ), "pathway dimensions are not consistent."
        self.cfg = cfg
        self.use_fpn = use_fpn
        self.gt_boxes_prob = cfg.MODEL.SparseRCNN.GT_BOXES_PROB
        self.num_pathways = len(pool_size)
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.use_action_heads = cfg.MODEL.SparseRCNN.NUM_ACT_HEADS > 0
        # move conv1x1 dim_before_proj --> 256 from backbone to head
        if self.use_action_heads:
            self.proj_to_256 = nn.Conv3d(dim_before_proj, 256, kernel_size=1)
        for pathway in range(self.num_pathways):
            temporal_pool = nn.AvgPool3d(
                [pool_size[pathway][0], 1, 1], stride=1
            )
            self.add_module("s{}_tpool".format(pathway), temporal_pool)
            pooler = ROIPooler(
                output_size=resolution[pathway],
                scales=[1.0 / scale_factor[pathway]],
                sampling_ratio=2,
                pooler_type=cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE,
            )
            self.add_module("s{}_roi".format(pathway), pooler)
            if self.use_fpn:
                keyframe_pooler = self._init_box_pooler(cfg)
                self.add_module("s{}_keyroi".format(pathway), keyframe_pooler)
            if pathway == 0:
                rcnn_head = RCNNHead(cfg)
                head_series = _get_clones(rcnn_head, cfg.MODEL.SparseRCNN.NUM_HEADS)
                self.add_module("s{}_headseries".format(pathway), head_series)
                if self.use_action_heads:
                    act_rcnn_head = RCNNHead(cfg, origin=False)
                    act_head_series = _get_clones(act_rcnn_head, cfg.MODEL.SparseRCNN.NUM_ACT_HEADS)
                    self.add_module("s{}_actheadseries".format(pathway), act_head_series)
                    temp_head = RCNNHead3D(cfg)
                    temp_head_series = _get_clones(temp_head, cfg.MODEL.SparseRCNN.NUM_ACT_HEADS)
                    self.add_module("s{}_tempheadseries".format(pathway), temp_head_series)
            spatial_pool = nn.MaxPool2d(resolution[pathway], stride=1)
            self.add_module("s{}_spool".format(pathway), spatial_pool)
        if self.num_pathways == 2 and self.use_action_heads:
            proj = nn.Conv2d(512, 256, kernel_size=1)
            self.add_module("concat_proj", proj)
        self.return_intermediate = cfg.MODEL.SparseRCNN.DEEP_SUPERVISION
        self.use_focal = cfg.MODEL.SparseRCNN.USE_FOCAL
        self.num_classes = num_classes
        if not self.use_action_heads:
            if dropout_rate > 0:
                self.dropout = nn.Dropout(dropout_rate)
            self.ori_projection = nn.Linear(sum(dim_in), num_classes, bias=True)
        if self.use_focal:
            prior_prob = cfg.MODEL.SparseRCNN.PRIOR_PROB
            self.bias_value = -math.log((1 - prior_prob) / prior_prob)
        self._reset_parameters()
    @staticmethod
    def _init_box_pooler(cfg):
        pooler_resolution = 7
        pooler_scales = tuple(1.0 / x for x in [4, 8, 16, 32])
        sampling_ratio = 2
        pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE
        box_pooler = ROIPooler(
            output_size=pooler_resolution,
            scales=pooler_scales,
            sampling_ratio=sampling_ratio,
            pooler_type=pooler_type,
        )
        return box_pooler

    def _reset_parameters(self):
        # Initialize all parameters.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
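    # Xavier-uniform init touches every weight matrix (p.dim() > 1); biases
    # and LayerNorm parameters keep their PyTorch defaults.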
    def forward(self, images_whwh, fpn_features, inputs, init_bboxes, init_features, act_init_features, temp_init_features, criterion=None, targets=None):
        assert (
            len(inputs) == self.num_pathways
        ), "Input tensor does not contain {} pathways".format(self.num_pathways)
        inter_class_logits = []
        inter_action_logits = []
        inter_pred_bboxes = []
        # reduce
        feat_pre_reduce = [feat for feat in inputs]
        if self.use_action_heads:
            inputs[0] = self.proj_to_256(inputs[0])
        bs = len(feat_pre_reduce[0])  # first dim of one pathway
        bboxes = init_bboxes
        # (100, 256) -> (1, 100 * bs, 256)
        init_features = init_features[None].repeat(1, bs, 1)
        proposal_features = init_features.clone()
        if self.use_action_heads:
            act_init_features = act_init_features[None].repeat(1, bs, 1)
            act_proposal_features = act_init_features.clone()
            temp_init_features = temp_init_features[None].repeat(1, bs, 1)
            temp_proposal_features = temp_init_features.clone()
        if self.use_fpn:
            keyframe = fpn_features
        else:
            # We only take the keyframe from the Slow pathway.
            assert fpn_features is None, "Check Logic"
            num_frame = inputs[0].shape[2]
            if self.cfg.MODEL.SparseRCNN.KEYWAY:
                raise ValueError("Use FPN feature, or below KEYWAY maybe bug")
                keyframe = [inputs[0][:, :, -1]]
                inputs[0] = inputs[0][:, :, :-1]
                feat_pre_reduce[0] = feat_pre_reduce[0][:, :, :-1]
            else:
                keyframe = [inputs[0][:, :, num_frame // 2]]
        pool_out = []
        for pathway in range(self.num_pathways):
            t_pool = getattr(self, "s{}_tpool".format(pathway))
            out = t_pool(inputs[pathway] if self.use_action_heads else feat_pre_reduce[pathway])
            assert out.shape[2] == 1
            out = torch.squeeze(out, 2)
            pool_out.append(out)
        if self.use_action_heads:
            out = torch.cat(pool_out, dim=1)
            if self.num_pathways == 2:
                out = self.concat_proj(out)
        # Below this line we assume pathway is 0. For SlowFast, the two
        # pathways have been concatenated into a single one.
        pathway = 0
        key_roi_align = getattr(self, "s{}_keyroi".format(pathway)) if self.use_fpn else getattr(self, "s{}_roi".format(pathway))  # noqa
        for rcnn_head in getattr(self, "s{}_headseries".format(pathway)):
            class_logits, pred_bboxes, proposal_features, jitter_pred_bboxes = rcnn_head(keyframe, bboxes, proposal_features, key_roi_align, images_whwh=images_whwh)
            if self.return_intermediate:
                inter_class_logits.append(class_logits)
                inter_pred_bboxes.append(pred_bboxes)
            bboxes = pred_bboxes.detach()
        if self.cfg.MODEL.SparseRCNN.JITTER_BOX:
            ava_box = jitter_pred_bboxes.detach()
        else:
            ava_box = bboxes
        roi_align = getattr(self, "s{}_roi".format(pathway))
        if self.training:
            # Compute the person detector loss, plus the matching indices and idx.
            losses, indices, idx = self.person_detector_loss(inter_class_logits, inter_pred_bboxes, criterion, targets)
            # With probability self.gt_boxes_prob, replace the matched
            # predicted boxes with the corresponding GT boxes.
            if random.random() < self.gt_boxes_prob:  # random.random() is uniform over [0, 1)
                ava_box = ava_box.clone()
                ava_box[idx] = torch.cat([t['boxes_xyxy'][i] for t, (_, i) in zip(targets, indices)], dim=0)
        if self.use_action_heads:
            for act_rcnn_head, temp_head in zip(getattr(self, "s{}_actheadseries".format(pathway)),
                                                getattr(self, "s{}_tempheadseries".format(pathway))):
                # For two pathways, we use the Fast pathway (inputs[-1]) as
                # the source of the temporal feature.
                temp_helper, temp_proposal_features = temp_head(inputs[-1], ava_box, temp_proposal_features, roi_align)
                action_logits, act_proposal_features = act_rcnn_head([out], ava_box, act_proposal_features, roi_align, temp_helper)
                if self.return_intermediate:
                    inter_action_logits.append(action_logits)
        else:
            N, nr_boxes = bboxes.shape[:2]
            s_pool_out = []
            proposal_boxes = [Boxes(b) for b in bboxes]
            for i, po in enumerate(pool_out):
                roi_align = getattr(self, "s{}_roi".format(i))
                out = roi_align([po], proposal_boxes)
                s_pool_out.append(F.adaptive_max_pool2d(out, output_size=(1, 1)))
            x = torch.cat(s_pool_out, 1)
            if hasattr(self, "dropout"):
                x = self.dropout(x)
            x = x.view(N, nr_boxes, -1)
            action_logits = self.ori_projection(x)
            inter_action_logits.append(action_logits)
        if self.training:
            if self.cfg.MODEL.SparseRCNN.JHMDB_LOSS:
                act_loss = self.jhmdb_act_loss(inter_action_logits, targets, indices, idx)
            else:
                act_loss = self.action_cls_loss(inter_action_logits, targets, indices, idx)
            losses.update(act_loss)
            return losses
        # eval
        return dict(pred_logits=inter_class_logits[-1],
                    pred_boxes=inter_pred_bboxes[-1],
                    pred_actions=inter_action_logits[-1])
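    # Shape note: at eval time the dict above holds pred_logits
    # (N, nr_boxes, num_classes + 1), pred_boxes (N, nr_boxes, 4) in pixel
    # xyxy, and pred_actions (N, nr_boxes, num_actions), as produced by the
    # RCNNHead instances above.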
    def person_detector_loss(self, outputs_class, outputs_coord, criterion, targets):
        if self.return_intermediate:
            output = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1],
                      'aux_outputs': [{'pred_logits': a, 'pred_boxes': b}
                                      for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]}
        else:
            raise NotImplementedError
        loss_dict, indices, idx = criterion(output, targets)
        return loss_dict, indices, idx

    def action_cls_loss(self, output_action, targets, indices, idx):
        losses = {}
        target_actions_o = torch.cat([t["actions"][J] for t, (_, J) in zip(targets, indices)])
        for i, action_logits in enumerate(output_action):
            action = action_logits[idx]
            if not self.cfg.MODEL.SparseRCNN.SOFTMAX_POSE:
                if self.cfg.MODEL.LOSS_FUNC == 'focal_action':
                    act_loss = sigmoid_focal_loss_jit(action, target_actions_o, alpha=0.25, reduction='mean')
                else:
                    # BCE with logits; the Sigmoid is removed from the model.
                    act_loss = F.binary_cross_entropy_with_logits(action, target_actions_o)
            else:
                pose_pred = F.softmax(action[:, :14], dim=-1)  # the first 14 classes are pose labels
                other_pred = torch.sigmoid(action[:, 14:])
                action = torch.cat([pose_pred, other_pred], dim=-1)
                if self.cfg.MODEL.LOSS_FUNC == 'focal_action':
                    act_loss = focal_loss_wo_logits_jit(action, target_actions_o, alpha=0.25, reduction='mean')
                else:
                    act_loss = F.binary_cross_entropy(action, target_actions_o)
            losses.update({'loss_bce' + f'_{i}': act_loss})
        losses['loss_bce'] = losses.pop('loss_bce' + f'_{i}')  # rename the last loss key to keep the AVA meter happy
        return losses

    def jhmdb_act_loss(self, output_action, targets, indices, idx):
        losses = {}
        target_actions_o = torch.cat([t["actions"][J] for t, (_, J) in zip(targets, indices)])
        for i, action_logits in enumerate(output_action):
            action = action_logits[idx]
            label = target_actions_o.argmax(dim=1)
            act_loss = F.cross_entropy(action, label)
            losses.update({'loss_bce' + f'_{i}': act_loss})
        losses['loss_bce'] = losses.pop('loss_bce' + f'_{i}')  # rename the last loss key to keep the AVA meter happy
        return losses
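    # Target format (a sketch inferred from the usage above, not an
    # authoritative spec): each element of `targets` is a dict with at least
    #   'boxes_xyxy': FloatTensor (num_gt, 4), GT boxes in pixel xyxy,
    #   'actions':    FloatTensor (num_gt, num_actions), multi-hot labels;
    # `indices` and `idx` come from the matcher inside `criterion`.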
class RCNNHead(nn.Module):

    def __init__(self, cfg, scale_clamp: float = _DEFAULT_SCALE_CLAMP,
                 bbox_weights=(2.0, 2.0, 1.0, 1.0), origin=True):
        super().__init__()
        d_model = cfg.MODEL.SparseRCNN.HIDDEN_DIM
        num_classes = cfg.MODEL.SparseRCNN.NUM_CLASSES
        num_actions = cfg.MODEL.NUM_CLASSES
        dim_feedforward = cfg.MODEL.SparseRCNN.DIM_FEEDFORWARD
        nhead = cfg.MODEL.SparseRCNN.NHEADS
        dropout = cfg.MODEL.SparseRCNN.DROPOUT
        activation = cfg.MODEL.SparseRCNN.ACTIVATION
        self.d_model = d_model
        self.jitter_box = cfg.MODEL.SparseRCNN.JITTER_BOX

        # dynamic.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.inst_interact = DynamicConv(cfg)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.is_origin = origin  # origin = False for the action_proposal_feature head
        self.combine_mode = cfg.MODEL.SparseRCNN.ST_COMBINE

        # cls.
        if self.is_origin:
            num_cls = cfg.MODEL.SparseRCNN.NUM_CLS
            cls_module = list()
            for _ in range(num_cls):
                cls_module.append(nn.Linear(d_model, d_model, False))
                cls_module.append(nn.LayerNorm(d_model))
                cls_module.append(nn.ReLU(inplace=True))
            self.cls_module = nn.ModuleList(cls_module)

            # reg.
            num_reg = cfg.MODEL.SparseRCNN.NUM_REG
            reg_module = list()
            for _ in range(num_reg):
                reg_module.append(nn.Linear(d_model, d_model, False))
                reg_module.append(nn.LayerNorm(d_model))
                reg_module.append(nn.ReLU(inplace=True))
            self.reg_module = nn.ModuleList(reg_module)
        else:
            # act.
            if self.combine_mode == 'MHA':
                self.st_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
            if self.combine_mode == 'concat':
                fc_concat = 2
            else:
                fc_concat = 1
            num_act = cfg.MODEL.SparseRCNN.NUM_ACT
            assert num_act > 0, "at least 1, but got num_act {}".format(num_act)
            act_dim = cfg.MODEL.SparseRCNN.ACT_FC_DIM
            act_module = list()
            act_module.append(nn.Linear(d_model * fc_concat, act_dim, False))
            act_module.append(nn.LayerNorm(act_dim))
            act_module.append(nn.ReLU(inplace=True))
            for _ in range(num_act - 1):
                act_module.append(nn.Linear(act_dim, act_dim, False))
                act_module.append(nn.LayerNorm(act_dim))
                act_module.append(nn.ReLU(inplace=True))
            self.act_module = nn.ModuleList(act_module)

        # pred.
        self.use_focal = cfg.MODEL.SparseRCNN.USE_FOCAL
        if self.use_focal:
            self.class_logits = nn.Linear(d_model, num_classes)
            raise NotImplementedError
        else:
            assert num_classes == 1, "Check Person Detector num_classes {}".format(num_classes)
            if self.is_origin:
                self.class_logits = nn.Linear(d_model, num_classes + 1)
            else:
                self.action_logits = nn.Linear(act_dim, num_actions)
                # self.act = nn.Sigmoid()
        if self.is_origin:
            self.bboxes_delta = nn.Linear(d_model, 4)
            self.scale_clamp = scale_clamp
            self.bbox_weights = bbox_weights
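    # This head mirrors the Sparse R-CNN dynamic head (self-attention over
    # proposal features, DynamicConv instance interaction, FFN), extended
    # with an action branch (origin=False) that reuses the same trunk.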
    def forward(self, features, bboxes, pro_features, pooler, temp_helper=None, images_whwh=None):
        """
        :param bboxes: (N, nr_boxes, 4)
        :param pro_features: (N, nr_boxes, d_model)
        """
        N, nr_boxes = bboxes.shape[:2]

        # roi_feature.
        proposal_boxes = list()
        for b in range(N):
            proposal_boxes.append(Boxes(bboxes[b]))
        # proposal_boxes: List[Boxes], Boxes(100); features: List[Tensor] (N, d_model, H', W')
        # M below is the total number of boxes aggregated over all N batch
        # images; batch goes first because every image may have different boxes.
        roi_features = pooler(features, proposal_boxes)  # (M = N * nr_boxes, d_model, 7, 7)
        roi_features = roi_features.view(N * nr_boxes, self.d_model, -1).permute(2, 0, 1)

        # self_att.
        pro_features = pro_features.view(N, nr_boxes, self.d_model).permute(1, 0, 2)  # (nr_boxes, N, d_model)
        pro_features2 = self.self_attn(pro_features, pro_features, value=pro_features)[0]
        pro_features = pro_features + self.dropout1(pro_features2)
        pro_features = self.norm1(pro_features)

        # inst_interact. (nr_boxes, N, d_model) => (N, nr_boxes, d_model) => (1, N*nr_boxes, d_model)
        pro_features = pro_features.view(nr_boxes, N, self.d_model).permute(1, 0, 2).reshape(1, N * nr_boxes, self.d_model)
        pro_features2 = self.inst_interact(pro_features, roi_features)  # (N*nr_boxes, d_model)
        pro_features = pro_features + self.dropout2(pro_features2)  # broadcasts to (1, N*nr_boxes, d_model)
        obj_features = self.norm2(pro_features)

        # obj_feature.
        obj_features2 = self.linear2(self.dropout(self.activation(self.linear1(obj_features))))
        obj_features = obj_features + self.dropout3(obj_features2)
        obj_features = self.norm3(obj_features)  # (1, N*nr_boxes, d_model)

        # (N*nr_boxes, 1, d_model) => (N*nr_boxes, d_model)
        fc_feature = obj_features.transpose(0, 1).reshape(N * nr_boxes, -1)
        if self.is_origin:
            cls_feature = fc_feature.clone()
            reg_feature = fc_feature.clone()
            for cls_layer in self.cls_module:
                cls_feature = cls_layer(cls_feature)
            for reg_layer in self.reg_module:
                reg_feature = reg_layer(reg_feature)
            class_logits = self.class_logits(cls_feature)
            bboxes_deltas = self.bboxes_delta(reg_feature)
            pred_bboxes, jitter_pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.view(-1, 4), with_jitter=self.jitter_box, images_whwh=images_whwh, N=N, nr_boxes=nr_boxes)
            if jitter_pred_bboxes is not None:
                return class_logits.view(N, nr_boxes, -1), pred_bboxes.view(N, nr_boxes, -1), obj_features, jitter_pred_bboxes
            else:
                return class_logits.view(N, nr_boxes, -1), pred_bboxes.view(N, nr_boxes, -1), obj_features, None
        else:
            act_feature = self.combine_action_feat(fc_feature, temp_helper, N, nr_boxes).clone()
            for act_layer in self.act_module:
                act_feature = act_layer(act_feature)
            action_logits = self.action_logits(act_feature)
            return action_logits.view(N, nr_boxes, -1), obj_features
    def combine_action_feat(self, spatio_feat, tempo_feat, N=None, nr_boxes=None):
        if self.combine_mode == 'sum':
            return spatio_feat + tempo_feat
        elif self.combine_mode == 'concat':
            return torch.cat([spatio_feat, tempo_feat], dim=-1)
        elif self.combine_mode == 'MHA':  # multi-head attention
            tempo_feat = tempo_feat.view(N, nr_boxes, self.d_model).permute(1, 0, 2)  # (nr_boxes, N, d_model)
            spatio_feat = spatio_feat.view(N, nr_boxes, self.d_model).permute(1, 0, 2)
            st_feature = self.st_attn(tempo_feat, spatio_feat, value=spatio_feat)[0]
            # .view() would fail here because permute makes the tensor non-contiguous.
            return st_feature.permute(1, 0, 2).reshape(N * nr_boxes, self.d_model)
        elif self.combine_mode == 'none':
            return spatio_feat
        else:
            raise NotImplementedError("Check combine type {}".format(self.combine_mode))
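    # Output-dimension note: 'sum', 'MHA', and 'none' return
    # (N*nr_boxes, d_model) while 'concat' returns (N*nr_boxes, 2*d_model);
    # this matches fc_concat in __init__, which widens the first act_module
    # linear layer accordingly.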
    def apply_deltas(self, deltas, boxes, with_jitter=False, images_whwh=None, N=None, nr_boxes=None):
        """
        Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.
        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
                deltas[i] represents k potentially different class-specific
                box transformations for the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4)
        """
        boxes = boxes.to(deltas.dtype)
        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.bbox_weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        dw = deltas[:, 2::4] / ww
        dh = deltas[:, 3::4] / wh

        # Prevent sending too large values into torch.exp().
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        pred_boxes = torch.zeros_like(deltas)
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w  # x1
        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h  # y1
        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w  # x2
        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h  # y2

        if not with_jitter:
            return pred_boxes, None

        assert images_whwh is not None
        jitter_pred_box = torch.zeros_like(deltas)
        if not self.training:
            # https://github.com/MVIG-SJTU/AlphAction/blob/master/alphaction/structures/bounding_box.py#L197
            x_scale = 0.1
            y_scale = 0.05
            jitter_pred_box[:, 0::4] = pred_ctr_x - 0.5 * pred_w * (1 + x_scale)
            jitter_pred_box[:, 1::4] = pred_ctr_y - 0.5 * pred_h * (1 + y_scale)
            jitter_pred_box[:, 2::4] = pred_ctr_x + 0.5 * pred_w * (1 + x_scale)
            jitter_pred_box[:, 3::4] = pred_ctr_y + 0.5 * pred_h * (1 + y_scale)
            jitter_pred_box = jitter_pred_box.view(N, nr_boxes, -1)
            for idx, (boxes_per_image, curr_whwh) in enumerate(zip(jitter_pred_box, images_whwh)):
                jitter_pred_box[idx] = clip_boxes_tensor(boxes_per_image, curr_whwh[1], curr_whwh[0])
            return pred_boxes, jitter_pred_box
        else:
            # https://github.com/MVIG-SJTU/AlphAction/blob/master/alphaction/structures/bounding_box.py#L226
            jitter_x_out, jitter_x_in, jitter_y_out, jitter_y_in = 0.2, 0.1, 0.1, 0.05
            device = pred_boxes.device

            def torch_uniform(rows, a=0.0, b=1.0):
                return torch.rand(rows, 1, dtype=torch.float32, device=device) * (b - a) + a

            num_boxes = N * nr_boxes
            jitter_pred_box[:, 0::4] = pred_ctr_x - 0.5 * pred_w + pred_w * torch_uniform(num_boxes, -jitter_x_out, jitter_x_in)
            jitter_pred_box[:, 1::4] = pred_ctr_y - 0.5 * pred_h + pred_h * torch_uniform(num_boxes, -jitter_y_out, jitter_y_in)
            jitter_pred_box[:, 2::4] = pred_ctr_x + 0.5 * pred_w + pred_w * torch_uniform(num_boxes, -jitter_x_in, jitter_x_out)
            jitter_pred_box[:, 3::4] = pred_ctr_y + 0.5 * pred_h + pred_h * torch_uniform(num_boxes, -jitter_y_in, jitter_y_out)
            jitter_pred_box = jitter_pred_box.view(N, nr_boxes, -1)
            for idx, (_, curr_whwh) in enumerate(zip(jitter_pred_box, images_whwh)):
                jitter_pred_box[idx][0].clamp_(min=0, max=curr_whwh[0] - 1)
                jitter_pred_box[idx][1].clamp_(min=0, max=curr_whwh[1] - 1)
                jitter_pred_box[idx][2] = torch.max(torch.clamp(jitter_pred_box[idx][2], max=curr_whwh[0] - 1), jitter_pred_box[idx][0] + 1)
                jitter_pred_box[idx][3] = torch.max(torch.clamp(jitter_pred_box[idx][3], max=curr_whwh[1] - 1), jitter_pred_box[idx][1] + 1)
                jitter_pred_box[idx] = clip_boxes_tensor(jitter_pred_box[idx], curr_whwh[1], curr_whwh[0])
            return pred_boxes, jitter_pred_box
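    # Sanity check: with deltas == 0 the transform is the identity, since
    # dx = dy = 0 keeps the center and exp(0) = 1 keeps the width/height,
    # so pred_boxes equals the input boxes exactly.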
class DynamicConv(nn.Module):

    def __init__(self, cfg, origin=True):
        super().__init__()
        self.hidden_dim = cfg.MODEL.SparseRCNN.HIDDEN_DIM
        self.dim_dynamic = cfg.MODEL.SparseRCNN.DIM_DYNAMIC
        self.num_dynamic = cfg.MODEL.SparseRCNN.NUM_DYNAMIC
        self.num_params = self.hidden_dim * self.dim_dynamic
        self.dynamic_layer = nn.Linear(self.hidden_dim, self.num_dynamic * self.num_params)

        self.norm1 = nn.LayerNorm(self.dim_dynamic)
        self.norm2 = nn.LayerNorm(self.hidden_dim)

        self.activation = nn.ReLU(inplace=True)

        pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION
        if origin:
            num_output = self.hidden_dim * pooler_resolution ** 2
        else:
            num_output = self.hidden_dim * cfg.DATA.NUM_FRAMES
        self.out_layer = nn.Linear(num_output, self.hidden_dim)
        self.norm3 = nn.LayerNorm(self.hidden_dim)

    def forward(self, pro_features, roi_features):
        '''
        pro_features: (1, N * nr_boxes, self.d_model)
        roi_features: (49, N * nr_boxes, self.d_model)
        '''
        features = roi_features.permute(1, 0, 2)  # (N*nr_boxes, 49, 256)
        parameters = self.dynamic_layer(pro_features).permute(1, 0, 2)  # (N*nr_boxes, 1, 32768)

        param1 = parameters[:, :, :self.num_params].view(-1, self.hidden_dim, self.dim_dynamic)  # (N*nr_boxes, 256, 64)
        param2 = parameters[:, :, self.num_params:].view(-1, self.dim_dynamic, self.hidden_dim)  # (N*nr_boxes, 64, 256)

        features = torch.bmm(features, param1)  # (N*nr_boxes, 49, 64)
        features = self.norm1(features)
        features = self.activation(features)

        features = torch.bmm(features, param2)  # (N*nr_boxes, 49, 256)
        features = self.norm2(features)
        features = self.activation(features)

        features = features.flatten(start_dim=1)  # (N*nr_boxes, 49*256)
        features = self.out_layer(features)
        features = self.norm3(features)
        features = self.activation(features)

        return features
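# DynamicConv implements the per-proposal "dynamic instance interaction" of
# Sparse R-CNN: dynamic_layer predicts the weights of two 1x1 convolutions
# for each proposal (param1: hidden_dim x dim_dynamic, param2:
# dim_dynamic x hidden_dim), applied to that proposal's RoI tokens via
# batched matrix multiplication.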
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
class RCNNHead3D(nn.Module):

    def __init__(self, cfg, scale_clamp: float = _DEFAULT_SCALE_CLAMP,
                 bbox_weights=(2.0, 2.0, 1.0, 1.0), origin=True):
        super().__init__()
        d_model = cfg.MODEL.SparseRCNN.HIDDEN_DIM
        num_classes = cfg.MODEL.SparseRCNN.NUM_CLASSES
        num_actions = cfg.MODEL.NUM_CLASSES
        dim_feedforward = cfg.MODEL.SparseRCNN.DIM_FEEDFORWARD
        nhead = cfg.MODEL.SparseRCNN.NHEADS
        dropout = cfg.MODEL.SparseRCNN.DROPOUT
        activation = cfg.MODEL.SparseRCNN.ACTIVATION
        self.d_model = d_model
        self.jitter_box = cfg.MODEL.SparseRCNN.JITTER_BOX

        # dynamic.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.inst_interact = DynamicConv(cfg, origin=False)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.is_origin = origin  # origin = False for the action_proposal_feature head

    def forward(self, features, bboxes, pro_features, pooler):
        """
        :param bboxes: (N, nr_boxes, 4)
        :param pro_features: (N, nr_boxes, d_model)
        """
        N, nr_boxes = bboxes.shape[:2]
        nr_frames = features.shape[2]

        # roi_feature.
        proposal_boxes = list()
        for b in range(N):
            proposal_boxes.append(Boxes(bboxes[b]))
        roi_feats = []
        # Only the Slow pathway is considered here.
        for k in range(nr_frames):
            frame_roi_features = pooler([features[:, :, k]], proposal_boxes)
            frame_roi_features = F.adaptive_avg_pool2d(frame_roi_features, output_size=(1, 1))
            roi_feats.append(frame_roi_features)
        roi_features = torch.stack(roi_feats, dim=2)  # (N*nr_boxes, d_model, nr_frames, 1, 1)
        # roi_features = pooler(features, proposal_boxes)
        roi_features = roi_features.view(N * nr_boxes, self.d_model, -1).permute(2, 0, 1)

        # self_att.
        pro_features = pro_features.view(N, nr_boxes, self.d_model).permute(1, 0, 2)
        pro_features2 = self.self_attn(pro_features, pro_features, value=pro_features)[0]
        pro_features = pro_features + self.dropout1(pro_features2)
        pro_features = self.norm1(pro_features)

        # inst_interact.
        pro_features = pro_features.view(nr_boxes, N, self.d_model).permute(1, 0, 2).reshape(1, N * nr_boxes, self.d_model)
        pro_features2 = self.inst_interact(pro_features, roi_features)
        pro_features = pro_features + self.dropout2(pro_features2)
        obj_features = self.norm2(pro_features)

        # obj_feature.
        obj_features2 = self.linear2(self.dropout(self.activation(self.linear1(obj_features))))
        obj_features = obj_features + self.dropout3(obj_features2)
        obj_features = self.norm3(obj_features)
        fc_feature = obj_features.transpose(0, 1).reshape(N * nr_boxes, -1)
        return fc_feature, obj_features
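# Unlike RCNNHead, which feeds 7x7 = 49 RoI tokens into DynamicConv,
# RCNNHead3D average-pools each frame's RoI to 1x1 and uses the nr_frames
# per-frame vectors as tokens; hence its DynamicConv is built with
# origin=False (num_output = hidden_dim * cfg.DATA.NUM_FRAMES).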
class X3DHead(nn.Module):
    """X3D head before the global average pooling.
    Copy-pasted from slowfast/head_helper.py; only the layers before the
    global average pooling layer are kept.
    """

    def __init__(
        self,
        dim_in,
        dim_inner,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        super(X3DHead, self).__init__()
        self.eps = eps
        self.bn_mmt = bn_mmt
        self.inplace_relu = inplace_relu
        self._construct_head(dim_in, dim_inner, norm_module)

    def _construct_head(self, dim_in, dim_inner, norm_module):
        self.conv_5 = nn.Conv3d(
            dim_in,
            dim_inner,
            kernel_size=(1, 1, 1),
            stride=(1, 1, 1),
            padding=(0, 0, 0),
            bias=False,
        )
        self.conv_5_bn = norm_module(
            num_features=dim_inner, eps=self.eps, momentum=self.bn_mmt
        )
        self.conv_5_relu = nn.ReLU(self.inplace_relu)

    def forward(self, inputs):
        # In its current design the X3D head is only usable for a
        # single-pathway input.
        assert len(inputs) == 1, "Input tensor does not contain 1 pathway"
        x = self.conv_5(inputs[0])
        x = self.conv_5_bn(x)
        x = self.conv_5_relu(x)
        return x
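

# A minimal smoke test for the self-contained X3DHead above (added sketch;
# the channel sizes 192 -> 432 are hypothetical X3D-style values, chosen
# only for illustration).
if __name__ == "__main__":
    head = X3DHead(dim_in=192, dim_inner=432)
    clip = torch.randn(2, 192, 4, 7, 7)  # (batch, channels, T, H, W)
    out = head([clip])  # the single-pathway input is passed as a one-element list
    assert out.shape == (2, 432, 4, 7, 7)
    print("X3DHead output shape:", tuple(out.shape))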
Will the complete code be released, or could you share it, please?
How can I use this code with my custom dataset? Could you provide a tutorial?
What do the ResNetRoIHead inputs init_features, act_init_features, and temp_init_features mean?