Object detection error article: tests for all the important code shared in the article.
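These tests assume the article's classify_errors.py and error_impact.py modules sit in the same directory as this file (an assumption based on the imports below); from that directory, a plain pytest invocation should collect and run everything:

    pytest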
# Copyright © 2022 Bernat Puig Camps
import pandas as pd
import pytest

from classify_errors import (
    PREDS_DF_COLUMNS,
    TARGETS_DF_COLUMNS,
    ErrorType,
    classify_predictions_errors,
)
from error_impact import (
    calculate_error_impact,
    ensure_consistency,
    fix_bkg_error,
    fix_cls_error,
    fix_loc_error,
    fix_miss_error,
)


def test_classify_predictions_errors():
    # Disable black to keep readability of the defined DataFrames
    # fmt: off
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
        {"target_id": 1, "image_id": 0, "label_id": 2, "xmin": 10, "ymin": 0, "xmax": 20, "ymax": 10},
        {"target_id": 2, "image_id": 1, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
    ])
    preds_df = pd.DataFrame.from_records([
        # OK for target 0
        {"pred_id": 0, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # DUP for target 0
        {"pred_id": 1, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.8},
        # CLS for target 0
        {"pred_id": 2, "image_id": 0, "label_id": 2, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # LOC for target 0
        {"pred_id": 3, "image_id": 0, "label_id": 1, "xmin": 6, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # CLS & LOC for all targets
        {"pred_id": 4, "image_id": 0, "label_id": 2, "xmin": 6, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # CLS for target 0, LOC for target 1 (LOC should take precedence due to correct label)
        {"pred_id": 5, "image_id": 0, "label_id": 2, "xmin": 0, "ymin": 0, "xmax": 15, "ymax": 10, "score": 0.9},
        # BKG: no overlap with any target
{"pred_id": 6, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 10, "xmax": 0, "ymax": 20, "score": 0.9}, | |
    ])
    # fmt: on
    expected_errors_df = pd.DataFrame.from_records(
        [
            {"pred_id": 0, "target_id": 0, "error_type": ErrorType.OK},
            {"pred_id": 1, "target_id": 0, "error_type": ErrorType.DUP},
            {"pred_id": 2, "target_id": 0, "error_type": ErrorType.CLS},
            {"pred_id": 3, "target_id": 0, "error_type": ErrorType.LOC},
            {"pred_id": 4, "target_id": None, "error_type": ErrorType.CLS_LOC},
            {"pred_id": 5, "target_id": 1, "error_type": ErrorType.LOC},
            {"pred_id": 6, "target_id": None, "error_type": ErrorType.BKG},
            {"pred_id": None, "target_id": 2, "error_type": ErrorType.MISS},
        ]
    ).astype(
        {"pred_id": float, "target_id": float, "error_type": pd.StringDtype()}
    )
    errors_df = classify_predictions_errors(targets_df, preds_df)
    # Sort so the indexes are the same
    errors_df = errors_df.sort_values(by="pred_id").reset_index(drop=True)
    expected_errors_df = expected_errors_df.sort_values(
        by="pred_id"
    ).reset_index(drop=True)
    pd.testing.assert_frame_equal(
        errors_df, expected_errors_df, check_like=True
    )
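

# The IoU values implied by the fixture above are consistent with a
# TIDE-style scheme (correct label and IoU >= 0.5 -> OK, correct label
# and an intermediate IoU -> LOC, and so on); the exact thresholds live
# in classify_errors and are an assumption here. A minimal helper,
# shown only to make those numbers checkable:
def _iou(box_a, box_b):
    """IoU of two (xmin, ymin, xmax, ymax) boxes."""
    ix = max(0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    iy = max(0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = ix * iy
    union = (
        (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
        + (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
        - inter
    )
    return inter / union if union else 0.0


# pred 3 vs target 0: IoU 0.4 with a correct label -> classified LOC above
assert _iou((6, 0, 10, 10), (0, 0, 10, 10)) == 0.4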


def test_calculate_error_impact():
    # fmt: off
    # One OK pred; the error breakdown is irrelevant because the metric is a dummy
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
    ])
    preds_df = pd.DataFrame.from_records([
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
    ])
    errors_df = pd.DataFrame.from_records([
{"pred_id": 0, "target_id": 0, "error_type": ErrorType.OK}, | |
]) | |
# fmt: on | |
def metric_fn(targets_df, preds_df): | |
return 1.0 | |
metric = metric_fn | |
metric_name = "dummy" | |
expected_impact = { | |
metric_name: 1.0, | |
ErrorType.CLS: 0.0, | |
ErrorType.LOC: 0.0, | |
ErrorType.CLS_LOC: 0.0, | |
ErrorType.DUP: 0.0, | |
ErrorType.BKG: 0.0, | |
ErrorType.MISS: 0.0, | |
} | |
impact = calculate_error_impact( | |
metric_name, metric, errors_df, targets_df, preds_df | |
) | |
assert expected_impact == impact | |
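

# For reference, the article computes impact counterfactually: fix all
# errors of one type, recompute the metric, and report the delta versus
# the baseline. A hedged sketch of that loop; the authoritative logic
# is in error_impact, and FIXERS is a hypothetical mapping used only
# for illustration:
#
#     def sketch_error_impact(metric_name, metric, errors_df, targets_df, preds_df):
#         impact = {metric_name: metric(targets_df, preds_df)}
#         for error_type, fixer in FIXERS.items():  # e.g. {ErrorType.CLS: fix_cls_error, ...}
#             fixed_targets_df, fixed_preds_df = fixer(errors_df, targets_df, preds_df)
#             impact[error_type] = metric(fixed_targets_df, fixed_preds_df) - impact[metric_name]
#         return impact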


def test_ensure_consistency():
    # fmt: off
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 10, "ymin": 10, "xmax": 20, "ymax": 20},
    ])
    preds_df = pd.DataFrame.from_records([
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
    ])
    errors_df = pd.DataFrame.from_records(
        [
            {"pred_id": 0, "target_id": pd.NA, "error_type": ErrorType.BKG},
            # MISS error for target is missing
        ]
    )
    # fmt: on
    with pytest.raises(ValueError):
        ensure_consistency(errors_df, targets_df, preds_df)
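

# As implied by this test, ensure_consistency guards the counterfactual
# fixes: every pred_id in preds_df and every target_id in targets_df
# must show up in errors_df, otherwise a "fixed" metric would silently
# ignore rows. Above, target 0 has no MISS row, hence the expected
# ValueError.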


def test_fix_cls_error():
    # fmt: off
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
        {"target_id": 1, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
    ])
    preds_df = pd.DataFrame.from_records([
        # Correct pred with low score -> should be kept
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
        # CLS error and highest scoring pred, but an OK pred exists -> should be removed
        {"pred_id": 1, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # CLS error and highest scoring pred -> should be corrected
        {"pred_id": 2, "image_id": 1, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # CLS error and not the highest scoring pred -> should be removed
        {"pred_id": 3, "image_id": 1, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
    ])
    errors_df = pd.DataFrame.from_records(
        [
            {"pred_id": 0, "target_id": 0, "error_type": ErrorType.OK},
            {"pred_id": 1, "target_id": 0, "error_type": ErrorType.CLS},
            {"pred_id": 2, "target_id": 1, "error_type": ErrorType.CLS},
            {"pred_id": 3, "target_id": 1, "error_type": ErrorType.CLS},
        ]
    )
    expected_fixed_preds_df = pd.DataFrame.from_records([
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
        {"pred_id": 2, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
    ])
    # fmt: on
    fixed_targets_df, fixed_preds_df = fix_cls_error(
        errors_df, targets_df, preds_df
    )
    pd.testing.assert_frame_equal(
        fixed_preds_df.set_index("pred_id"),
        expected_fixed_preds_df.set_index("pred_id"),
        check_like=True,
    )
    pd.testing.assert_frame_equal(targets_df, fixed_targets_df)
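

# The fixture above encodes the fixing policy implied by the inline
# comments: for each target with CLS errors, if an OK pred already
# exists every CLS pred is dropped; otherwise the highest-scoring CLS
# pred gets its label corrected and the rest are dropped.
# test_fix_loc_error below exercises the same policy for LOC errors,
# with the box (rather than the label) being corrected.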


def test_fix_loc_error():
    # fmt: off
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
        {"target_id": 1, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
    ])
    preds_df = pd.DataFrame.from_records([
        # Correct pred with low score -> should be kept
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
        # LOC error and highest scoring pred, but an OK pred exists -> should be removed
        {"pred_id": 1, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 8, "xmax": 10, "ymax": 10, "score": 0.9},
        # LOC error and highest scoring pred -> should be corrected
        {"pred_id": 2, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 8, "xmax": 10, "ymax": 10, "score": 0.9},
        # LOC error and not the highest scoring pred -> should be removed
        {"pred_id": 3, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 8, "xmax": 10, "ymax": 10, "score": 0.6},
    ])
    errors_df = pd.DataFrame.from_records(
        [
            {"pred_id": 0, "target_id": 0, "error_type": ErrorType.OK},
            {"pred_id": 1, "target_id": 0, "error_type": ErrorType.LOC},
            {"pred_id": 2, "target_id": 1, "error_type": ErrorType.LOC},
            {"pred_id": 3, "target_id": 1, "error_type": ErrorType.LOC},
        ]
    )
    expected_fixed_preds_df = pd.DataFrame.from_records([
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
        {"pred_id": 2, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
    ])
    # fmt: on
    fixed_targets_df, fixed_preds_df = fix_loc_error(
        errors_df, targets_df, preds_df
    )
    pd.testing.assert_frame_equal(
        fixed_preds_df.set_index("pred_id"),
        expected_fixed_preds_df.set_index("pred_id"),
        check_like=True,
    )
    pd.testing.assert_frame_equal(targets_df, fixed_targets_df)


def test_fix_by_removing_preds():
    """Test with BKG; this implicitly covers all fixing functions that fix by removing preds."""
    # fmt: off
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
    ])
    preds_df = pd.DataFrame.from_records([
        # BKG pred -> should be removed
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 10, "ymin": 10, "xmax": 20, "ymax": 20, "score": 0.6},
        # CLS error (or any other) -> should be kept
        {"pred_id": 1, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
    ])
    errors_df = pd.DataFrame.from_records(
        [
{"pred_id": 0, "target_id": 0, "error_type": ErrorType.BKG}, | |
{"pred_id": 1, "target_id": 0, "error_type": ErrorType.CLS}, | |
] | |
) | |
expected_fixed_preds_df = pd.DataFrame.from_records([ | |
{"pred_id": 1, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6}, | |
]) | |
# fmt: on | |
fixed_targets_df, fixed_preds_df = fix_bkg_error( | |
errors_df, targets_df, preds_df | |
) | |
pd.testing.assert_frame_equal( | |
fixed_preds_df.set_index("pred_id"), | |
expected_fixed_preds_df.set_index("pred_id"), | |
check_like=True, | |
) | |
pd.testing.assert_frame_equal(targets_df, fixed_targets_df) | |
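

# A removal-based fixer reduces to a single filter over preds_df. A
# hedged sketch under that assumption (illustrative only; the real
# implementations live in error_impact):
#
#     def sketch_fix_by_removal(error_type, errors_df, targets_df, preds_df):
#         bad_ids = errors_df.loc[errors_df["error_type"] == error_type, "pred_id"]
#         return targets_df, preds_df[~preds_df["pred_id"].isin(bad_ids)]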


def test_fix_by_removing_targets():
    # fmt: off
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
        # Target that will not have a prediction
        {"target_id": 1, "image_id": 0, "label_id": 0, "xmin": 10, "ymin": 10, "xmax": 20, "ymax": 20},
    ])
    preds_df = pd.DataFrame.from_records([
        # CLS error for target 0 (any non-MISS error type works here)
{"pred_id": 0, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6}, | |
# LOC error | |
{"pred_id": 1, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 8, "xmax": 10, "ymax": 10, "score": 0.6}, | |
]) | |
errors_df = pd.DataFrame.from_records( | |
[ | |
{"pred_id": 0, "target_id": 0, "error_type": ErrorType.CLS}, | |
{"pred_id": 1, "target_id": 0, "error_type": ErrorType.LOC}, | |
{"pred_id": pd.NA, "target_id": 1, "error_type": ErrorType.MISS}, | |
] | |
) | |
expected_fixed_targets_df = pd.DataFrame.from_records([ | |
{"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10} | |
]) | |
# fmt: on | |
fixed_targets_df, fixed_preds_df = fix_miss_error( | |
errors_df, targets_df, preds_df | |
) | |
pd.testing.assert_frame_equal(fixed_preds_df, preds_df) | |
pd.testing.assert_frame_equal( | |
expected_fixed_targets_df.set_index("target_id"), | |
fixed_targets_df.set_index("target_id"), | |
check_like=True, | |
) |