Created
July 28, 2021 11:36
-
-
Save Multihuntr/5a898e1794808ff7c6d30efca2ff52b7 to your computer and use it in GitHub Desktop.
Mean average precision for object detection
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reimplementation of: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/object_detection/metrics/mean_avg_precision.py
# Now with more vectorisation!
def precision_recall_curve_th(is_tp, confs, n_true, eps=1e-8):
    """
    Computes the precision/recall curve of a set of scored detections

    Detections are ranked by descending confidence; entry k of each returned
    tensor describes the k+1 most confident detections.

    Parameters:
        is_tp (torch.Tensor): 1-D tensor, 1 where the detection is a true
            positive and 0 where it is a false positive
        confs (torch.Tensor): 1-D tensor of confidence scores aligned with is_tp
        n_true (int): number of ground-truth boxes, the denominator of recall
        eps (float): added to the denominators to avoid division by zero
    Returns:
        tuple: (precisions, recalls), each a 1-D tensor with len(is_tp) entries
    """
    import torch  # local import keeps this function usable on its own

    # Rank detections from most to least confident
    order = (-confs).argsort()
    tp_cumsum = is_tp[order].cumsum(dim=0)
    # Number of detections kept at each cut-off: 1, 2, ..., N. Built on the
    # inputs' device so CUDA tensors do not trigger a device-mismatch error.
    n_pred = torch.arange(1, len(order) + 1, device=is_tp.device)
    precisions = tp_cumsum / (n_pred + eps)
    recalls = tp_cumsum / (n_true + eps)
    return precisions, recalls
def mean_average_precision(pred_boxes, true_boxes, iou_thresh=0.5, box_format="corners"): | |
""" | |
Calculates mean average precision | |
Parameters: | |
pred_boxes (list): list of lists containing all bboxes with each bboxes | |
specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2] | |
true_boxes (list): Similar as pred_boxes except all the correct ones | |
iou_threshold (float): threshold where predicted bboxes is correct | |
Returns: | |
float: mAP value across all classes given a specific IoU threshold | |
""" | |
average_precisions = [] | |
classes = set(true_boxes[:, 1].tolist()) | |
for c in classes: | |
# Get only the boxes for class with index c | |
detections = pred_boxes[pred_boxes[:, 1] == c] | |
ground_truths = true_boxes[true_boxes[:, 1] == c] | |
total_true_bboxes = len(ground_truths) | |
is_tps = [] | |
confs = [] | |
for i in set(ground_truths[:, 0].tolist()): | |
# Get only the boxes for image with index i | |
det_i = detections[detections[:, 0] == i] | |
gt_i = ground_truths[ground_truths[:, 0] == i] | |
# Calculate IoUs for all pairs of det/gt | |
ious = intersection_over_union(det_i[:, None, 2:], gt_i[None, :, 2:], box_format=box_format) | |
ious = ious.squeeze(-1) | |
# Remove all gt boxes which don't have any detections close enough | |
gt_max, _ = ious.max(dim=0) | |
ious = ious[:, gt_max >= iou_thresh] | |
# Select the first det box above iou_thresh for each remaining gt | |
_, det_max_idx = (ious >= iou_thresh).max(dim=0) | |
is_tp = torch.zeros(det_i.shape[0]) | |
is_tp[det_max_idx] = 1 | |
is_tps.append(is_tp) | |
confs.append(det_i[:, 2]) | |
is_tps = torch.cat(is_tps) | |
confs = torch.cat(confs) | |
# Find average_precision for this class | |
precision, recall = precision_recall_curve_th(is_tps, confs, total_true_bboxes) | |
precision = torch.cat((torch.tensor([1]), precision)) | |
recall = torch.cat((torch.tensor([0]), recall)) | |
average_precisions.append(torch.trapz(precision, recall)) | |
return sum(average_precisions) / len(average_precisions) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment