object detection error analysis article - Calculate the impact each error type has on a metric

# Copyright © 2022 Bernat Puig Camps
from typing import Callable, Dict, Tuple

import pandas as pd

from classify_errors import PREDS_DF_COLUMNS, TARGETS_DF_COLUMNS, ErrorType


def calculate_error_impact(
    metric_name: str,
    metric: Callable,
    errors_df: pd.DataFrame,
    targets_df: pd.DataFrame,
    preds_df: pd.DataFrame,
) -> Dict[str, float]:
"""Calculate the `metric` and the independant impact each error type has on it | |
Impact is defined as the (metric_after_fixing - metric_before_fixing). | |
Note that all error impacts and the metric will not add to 1. Nonetheless, | |
the errors (and fixes) are defined in such a way that applying all fixes | |
would end up with a perfect metric score. | |
:param metric_name: Name of the metric to display for logging purposes. | |
:param metric: Callable that will be called as metric(targets_df, preds_df) | |
and returns a float. | |
:param errors_df: DataFrame with error classification for all preds and targets | |
:param targets_df: DataFrame with the targets. | |
:param preds_df: DataFrame with the predictions. | |
:return impact: Dictionary with one key for the metric without fixing and | |
one for each error type. | |
""" | |
    ensure_consistency(errors_df, targets_df, preds_df)
    metric_values = {
        ErrorType.CLS: metric(*fix_cls_error(errors_df, targets_df, preds_df)),
        ErrorType.LOC: metric(*fix_loc_error(errors_df, targets_df, preds_df)),
        ErrorType.CLS_LOC: metric(
            *fix_cls_loc_error(errors_df, targets_df, preds_df)
        ),
        ErrorType.DUP: metric(*fix_dup_error(errors_df, targets_df, preds_df)),
        ErrorType.BKG: metric(*fix_bkg_error(errors_df, targets_df, preds_df)),
        ErrorType.MISS: metric(
            *fix_miss_error(errors_df, targets_df, preds_df)
        ),
    }
    # Compute the metric on the actual, unfixed results
    baseline_metric = metric(targets_df, preds_df)
    # The impact of each error type is the change in the metric when that
    # error type alone is fixed
    impact = {
        error: (error_metric - baseline_metric)
        for error, error_metric in metric_values.items()
    }
    impact[metric_name] = baseline_metric
    return impact
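
# Illustrative arithmetic for the impact definition above (numbers are made
# up, not from the gist): if the baseline metric is 0.60 and the metric after
# fixing only MISS errors is 0.67, then impact[ErrorType.MISS] is
# 0.67 - 0.60 = 0.07. Each impact is computed independently from the same
# baseline, which is why the impacts need not sum to (1 - baseline).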


def ensure_consistency(
    errors_df: pd.DataFrame, targets_df: pd.DataFrame, preds_df: pd.DataFrame
):
    """Make sure that all targets and preds are accounted for in errors_df"""
    target_ids = set(targets_df["target_id"])
    pred_ids = set(preds_df["pred_id"])
    error_target_ids = set(errors_df.query("target_id.notnull()")["target_id"])
    error_pred_ids = set(errors_df.query("pred_id.notnull()")["pred_id"])
    if not target_ids == error_target_ids:
        raise ValueError(
            f"Missing target IDs in errors_df: {target_ids - error_target_ids}"
        )
    if not pred_ids == error_pred_ids:
        raise ValueError(
            f"Missing pred IDs in errors_df: {pred_ids - error_pred_ids}"
        )


def fix_cls_error(
    errors_df, targets_df, preds_df
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    return _fix_by_correcting_and_removing_preds(
        errors_df, targets_df, preds_df, ErrorType.CLS
    )


def fix_loc_error(
    errors_df, targets_df, preds_df
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    return _fix_by_correcting_and_removing_preds(
        errors_df, targets_df, preds_df, ErrorType.LOC
    )


def fix_cls_loc_error(
    errors_df, targets_df, preds_df
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    return _fix_by_removing_preds(
        errors_df, targets_df, preds_df, ErrorType.CLS_LOC
    )


def fix_bkg_error(
    errors_df, targets_df, preds_df
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    return _fix_by_removing_preds(
        errors_df, targets_df, preds_df, ErrorType.BKG
    )


def fix_dup_error(
    errors_df, targets_df, preds_df
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    return _fix_by_removing_preds(
        errors_df, targets_df, preds_df, ErrorType.DUP
    )


def fix_miss_error(
    errors_df, targets_df, preds_df
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Fix missed targets by removing them

    Missed targets are the only error type that deals with targets rather
    than predictions.

    :return: Fixed (`targets_df`, `preds_df`)
    """
    ensure_consistency(errors_df, targets_df, preds_df)
    targets_df = targets_df.merge(
        # Filter out the rest of the errors; otherwise targets matched by
        # multiple predictions would make targets_df grow after the merge
        errors_df.query("error_type == @ErrorType.MISS"),
        on="target_id",
        how="left",
    ).query("error_type.isnull()")
    return targets_df[TARGETS_DF_COLUMNS], preds_df
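
# Illustrative example of the fix above (IDs are made up): if targets_df
# holds T1, T2 and T3, and errors_df marks T2 as ErrorType.MISS, the fixed
# targets_df keeps only T1 and T3 while preds_df is returned untouched.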


def _fix_by_correcting_and_removing_preds(
    errors_df: pd.DataFrame,
    targets_df: pd.DataFrame,
    preds_df: pd.DataFrame,
    error_type: ErrorType,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Correct predictions of `error_type` for unmatched targets, remove the rest

    CLS and LOC errors are matched to targets. To assess their impact, we
    correct the highest-scoring prediction of each unmatched target (one with
    no OK prediction):

    - For CLS, we set the label to the correct one.
    - For LOC, we set the bounding box to match the target's perfectly.

    The non-corrected predictions of `error_type` are removed from `preds_df`.
    The idea is to assess what would have happened if, instead of missing a
    target due to an incorrect prediction, we had made a correct one. The
    predictions that are not the highest-scoring for their target would have
    become duplicates, so we remove them.

    :return: Fixed (`targets_df`, `preds_df`)
    """
    assert error_type in {
        ErrorType.CLS,
        ErrorType.LOC,
    }, f"error_type='{error_type}'"
    ensure_consistency(errors_df, targets_df, preds_df)
    cols_to_correct = {
        ErrorType.CLS: ["label_id"],
        ErrorType.LOC: ["xmin", "ymin", "xmax", "ymax"],
    }[error_type]
    # Add matched targets to relevant preds and sort so highest scoring is first
    preds_df = (
        preds_df.merge(
            errors_df.query(
                "error_type in [@ErrorType.OK, @ErrorType.CLS, @ErrorType.LOC]"
            ),
            on="pred_id",
            how="left",
        )
        .merge(
            targets_df[["target_id"] + cols_to_correct],
            on="target_id",
            how="left",
            suffixes=("", "_target"),
        )
        .sort_values(by="score", ascending=False)
    )
    to_correct = preds_df["error_type"].eq(error_type)
    target_cols = [col + "_target" for col in cols_to_correct]
    preds_df.loc[to_correct, cols_to_correct] = preds_df.loc[
        to_correct, target_cols
    ].values
    to_drop = []
    for _, target_df in preds_df.groupby("target_id"):
        if target_df["error_type"].eq(ErrorType.OK).any():
            # If the target has a correct prediction, drop all its
            # predictions of `error_type`
            to_drop += target_df.query("error_type == @error_type")[
                "pred_id"
            ].tolist()
        elif (
            target_df["error_type"].eq(error_type).any() and len(target_df) > 1
        ):
            # If the target is unmatched, drop all its predictions of
            # `error_type` except the highest-scoring one
            to_keep = target_df["pred_id"].iloc[0]
            to_drop += target_df.query(
                "error_type == @error_type and pred_id != @to_keep"
            )["pred_id"].tolist()
    return (
        targets_df,
        preds_df.query("pred_id not in @to_drop")[PREDS_DF_COLUMNS],
    )
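
# Illustrative walk-through of the correct-and-remove logic above (IDs and
# scores are made up): suppose target T1 has no OK prediction and two
# CLS-error predictions, P1 (score 0.9) and P2 (score 0.6). Both get their
# label corrected, P1 is kept because it scores highest, and P2 is dropped
# since it would now be a duplicate. If T1 also had an OK prediction, both
# P1 and P2 would be dropped instead.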


def _fix_by_removing_preds(
    errors_df: pd.DataFrame,
    targets_df: pd.DataFrame,
    preds_df: pd.DataFrame,
    error_type: ErrorType,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Fix the `error_type` by removing the predictions assigned to that error

    This is applicable to:

    - ErrorType.CLS_LOC and ErrorType.BKG, because there is no target we
      could match them to while being sure the model was "intending" to
      predict it.
    - ErrorType.DUP, by definition.

    :return: Fixed (`targets_df`, `preds_df`)
    """
    assert error_type in {
        ErrorType.CLS_LOC,
        ErrorType.BKG,
        ErrorType.DUP,
    }, f"error_type='{error_type}'"
    ensure_consistency(errors_df, targets_df, preds_df)
    preds_df = preds_df.merge(errors_df, on="pred_id", how="left").query(
        "error_type != @error_type"
    )
    return targets_df, preds_df[PREDS_DF_COLUMNS]
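
# A hypothetical end-to-end sketch (names not defined above are assumptions,
# not part of the gist): `build_errors_df` stands in for whatever the
# companion `classify_errors` file exposes to produce `errors_df`, and
# `my_map_metric` for any callable with signature
# metric(targets_df, preds_df) -> float, e.g. a wrapper around an mAP
# implementation.
#
#     errors_df = build_errors_df(targets_df, preds_df)
#     impact = calculate_error_impact(
#         metric_name="mAP",
#         metric=my_map_metric,
#         errors_df=errors_df,
#         targets_df=targets_df,
#         preds_df=preds_df,
#     )
#     for name, value in impact.items():
#         print(name, round(value, 3))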