# Copyright © 2022 Bernat Puig Camps
"""Tests for all the important code shared in the object detection error
article."""

import pandas as pd
import pytest

from classify_errors import (
    PREDS_DF_COLUMNS,
    TARGETS_DF_COLUMNS,
    ErrorType,
    classify_predictions_errors,
)
from error_impact import (
    calculate_error_impact,
    ensure_consistency,
    fix_bkg_error,
    fix_cls_error,
    fix_loc_error,
    fix_miss_error,
)
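
# Column layout assumed throughout these tests (inferred from the data below):
#   targets_df: target_id, image_id, label_id, xmin, ymin, xmax, ymax
#   preds_df:   pred_id, image_id, label_id, xmin, ymin, xmax, ymax, score
#   errors_df:  pred_id, target_id, error_type

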
def test_classify_predictions_errors():
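    """Every prediction is classified into exactly one error type and each
    unmatched target yields a MISS entry."""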
    # Disable black to keep readability on defined DataFrames
    # fmt: off
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
        {"target_id": 1, "image_id": 0, "label_id": 2, "xmin": 10, "ymin": 0, "xmax": 20, "ymax": 10},
        {"target_id": 2, "image_id": 1, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
    ])
    preds_df = pd.DataFrame.from_records([
        # OK for target 0
        {"pred_id": 0, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # DUP for target 0
        {"pred_id": 1, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.8},
        # CLS for target 0
        {"pred_id": 2, "image_id": 0, "label_id": 2, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # LOC for target 0
        {"pred_id": 3, "image_id": 0, "label_id": 1, "xmin": 6, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # CLS & LOC for all targets
        {"pred_id": 4, "image_id": 0, "label_id": 2, "xmin": 6, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # CLS for target 0, LOC for target 1 (LOC takes precedence due to the correct label)
        {"pred_id": 5, "image_id": 0, "label_id": 2, "xmin": 0, "ymin": 0, "xmax": 15, "ymax": 10, "score": 0.9},
        # BKG
        {"pred_id": 6, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 10, "xmax": 0, "ymax": 20, "score": 0.9},
    ])
    # fmt: on
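    # None entries upcast pred_id/target_id to float (NaN); the explicit
    # astype also casts error_type to pandas' nullable string dtype,
    # presumably matching the dtypes that classify_predictions_errors returns.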
    expected_errors_df = pd.DataFrame.from_records(
        [
            {"pred_id": 0, "target_id": 0, "error_type": ErrorType.OK},
            {"pred_id": 1, "target_id": 0, "error_type": ErrorType.DUP},
            {"pred_id": 2, "target_id": 0, "error_type": ErrorType.CLS},
            {"pred_id": 3, "target_id": 0, "error_type": ErrorType.LOC},
            {"pred_id": 4, "target_id": None, "error_type": ErrorType.CLS_LOC},
            {"pred_id": 5, "target_id": 1, "error_type": ErrorType.LOC},
            {"pred_id": 6, "target_id": None, "error_type": ErrorType.BKG},
            {"pred_id": None, "target_id": 2, "error_type": ErrorType.MISS},
        ]
    ).astype(
        {"pred_id": float, "target_id": float, "error_type": pd.StringDtype()}
    )

    errors_df = classify_predictions_errors(targets_df, preds_df)

    # Sort so the indexes are the same
    errors_df = errors_df.sort_values(by="pred_id").reset_index(drop=True)
    expected_errors_df = expected_errors_df.sort_values(
        by="pred_id"
    ).reset_index(drop=True)
    pd.testing.assert_frame_equal(
        errors_df, expected_errors_df, check_like=True
    )


def test_calculate_error_impact():
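    """With a constant metric, the baseline is reported under the metric name
    and every error type has zero impact."""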
    # fmt: off
    # One OK pred; the error classification is irrelevant because the metric
    # is a constant dummy.
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
    ])
    preds_df = pd.DataFrame.from_records([
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
    ])
    errors_df = pd.DataFrame.from_records([
        {"pred_id": 0, "target_id": 0, "error_type": ErrorType.OK},
    ])
    # fmt: on

    def metric_fn(targets_df, preds_df):
        return 1.0

    metric_name = "dummy"
    expected_impact = {
        metric_name: 1.0,
        ErrorType.CLS: 0.0,
        ErrorType.LOC: 0.0,
        ErrorType.CLS_LOC: 0.0,
        ErrorType.DUP: 0.0,
        ErrorType.BKG: 0.0,
        ErrorType.MISS: 0.0,
    }

    impact = calculate_error_impact(
        metric_name, metric_fn, errors_df, targets_df, preds_df
    )
    assert expected_impact == impact
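

# Hedged usage sketch (not part of the original tests): any metric with the
# same (targets_df, preds_df) -> float signature plugs into the same call,
# e.g. a hypothetical `mean_average_precision`:
#
#     impact = calculate_error_impact(
#         "mAP", mean_average_precision, errors_df, targets_df, preds_df
#     )
#     # impact["mAP"] would be the baseline score and each ErrorType entry the
#     # change in score from fixing that error type (zero above, since the
#     # dummy metric is constant).

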
def test_ensure_consistency():
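    """An errors_df that does not account for every prediction and target
    should be rejected."""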
    # fmt: off
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 10, "ymin": 10, "xmax": 20, "ymax": 20},
    ])
    preds_df = pd.DataFrame.from_records([
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
    ])
    errors_df = pd.DataFrame.from_records(
        [
            {"pred_id": 0, "target_id": pd.NA, "error_type": ErrorType.BKG},
            # The MISS error for the target is missing
        ]
    )
    # fmt: on

    with pytest.raises(ValueError):
        ensure_consistency(errors_df, targets_df, preds_df)


def test_fix_cls_error():
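    """Fixing CLS errors relabels the highest-scoring CLS prediction of each
    target without a correct prediction, removes the other CLS predictions,
    and leaves targets untouched."""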
    # fmt: off
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
        {"target_id": 1, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
    ])
    preds_df = pd.DataFrame.from_records([
        # Correct pred with low score -> should be kept
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
        # CLS error and highest-scoring pred, but an OK pred exists -> should be removed
        {"pred_id": 1, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # CLS error and highest-scoring pred for its target -> should be corrected
        {"pred_id": 2, "image_id": 1, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
        # CLS error and not the highest-scoring pred -> should be removed
        {"pred_id": 3, "image_id": 1, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
    ])
    errors_df = pd.DataFrame.from_records(
        [
            {"pred_id": 0, "target_id": 0, "error_type": ErrorType.OK},
            {"pred_id": 1, "target_id": 0, "error_type": ErrorType.CLS},
            {"pred_id": 2, "target_id": 1, "error_type": ErrorType.CLS},
            {"pred_id": 3, "target_id": 1, "error_type": ErrorType.CLS},
        ]
    )
    expected_fixed_preds_df = pd.DataFrame.from_records([
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
        {"pred_id": 2, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
    ])
    # fmt: on

    fixed_targets_df, fixed_preds_df = fix_cls_error(
        errors_df, targets_df, preds_df
    )

    pd.testing.assert_frame_equal(
        fixed_preds_df.set_index("pred_id"),
        expected_fixed_preds_df.set_index("pred_id"),
        check_like=True,
    )
    pd.testing.assert_frame_equal(targets_df, fixed_targets_df)


def test_fix_loc_error():
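    """Fixing LOC errors mirrors the CLS fix: the highest-scoring LOC
    prediction of each unmatched target gets the target's box, the rest are
    removed, and targets are untouched."""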
    # fmt: off
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
        {"target_id": 1, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
    ])
    preds_df = pd.DataFrame.from_records([
        # Correct pred with low score -> should be kept
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
        # LOC error and highest-scoring pred, but an OK pred exists -> should be removed
        {"pred_id": 1, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 8, "xmax": 10, "ymax": 10, "score": 0.9},
        # LOC error and highest-scoring pred for its target -> should be corrected
        {"pred_id": 2, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 8, "xmax": 10, "ymax": 10, "score": 0.9},
        # LOC error and not the highest-scoring pred -> should be removed
        {"pred_id": 3, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 8, "xmax": 10, "ymax": 10, "score": 0.6},
    ])
    errors_df = pd.DataFrame.from_records(
        [
            {"pred_id": 0, "target_id": 0, "error_type": ErrorType.OK},
            {"pred_id": 1, "target_id": 0, "error_type": ErrorType.LOC},
            {"pred_id": 2, "target_id": 1, "error_type": ErrorType.LOC},
            {"pred_id": 3, "target_id": 1, "error_type": ErrorType.LOC},
        ]
    )
    expected_fixed_preds_df = pd.DataFrame.from_records([
        {"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
        {"pred_id": 2, "image_id": 1, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.9},
    ])
    # fmt: on

    fixed_targets_df, fixed_preds_df = fix_loc_error(
        errors_df, targets_df, preds_df
    )

    pd.testing.assert_frame_equal(
        fixed_preds_df.set_index("pred_id"),
        expected_fixed_preds_df.set_index("pred_id"),
        check_like=True,
    )
    pd.testing.assert_frame_equal(targets_df, fixed_targets_df)


def test_fix_by_removing_preds():
"""We test with BKG but it implicitly tests all fixing functions that fix by removing"""
# fmt: off
targets_df = pd.DataFrame.from_records([
{"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
])
preds_df = pd.DataFrame.from_records([
# BKG pred -> Should be removed
{"pred_id": 0, "image_id": 0, "label_id": 0, "xmin": 10, "ymin": 10, "xmax": 20, "ymax": 20, "score": 0.6},
# CLS error (or any other) -> Should be kept
{"pred_id": 1, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
])
errors_df = pd.DataFrame.from_records(
[
{"pred_id": 0, "target_id": 0, "error_type": ErrorType.BKG},
{"pred_id": 1, "target_id": 0, "error_type": ErrorType.CLS},
]
)
expected_fixed_preds_df = pd.DataFrame.from_records([
{"pred_id": 1, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
])
# fmt: on
fixed_targets_df, fixed_preds_df = fix_bkg_error(
errors_df, targets_df, preds_df
)
pd.testing.assert_frame_equal(
fixed_preds_df.set_index("pred_id"),
expected_fixed_preds_df.set_index("pred_id"),
check_like=True,
)
pd.testing.assert_frame_equal(targets_df, fixed_targets_df)
def test_fix_by_removing_targets():
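    """Fixing MISS errors drops the unmatched targets and leaves the
    predictions untouched."""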
    # fmt: off
    targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
        # Target that will not have a prediction
        {"target_id": 1, "image_id": 0, "label_id": 0, "xmin": 10, "ymin": 10, "xmax": 20, "ymax": 20},
    ])
    preds_df = pd.DataFrame.from_records([
        # CLS error for target 0 (it does not need to be OK)
        {"pred_id": 0, "image_id": 0, "label_id": 1, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10, "score": 0.6},
        # LOC error for target 0
        {"pred_id": 1, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 8, "xmax": 10, "ymax": 10, "score": 0.6},
    ])
    errors_df = pd.DataFrame.from_records(
        [
            {"pred_id": 0, "target_id": 0, "error_type": ErrorType.CLS},
            {"pred_id": 1, "target_id": 0, "error_type": ErrorType.LOC},
            {"pred_id": pd.NA, "target_id": 1, "error_type": ErrorType.MISS},
        ]
    )
    expected_fixed_targets_df = pd.DataFrame.from_records([
        {"target_id": 0, "image_id": 0, "label_id": 0, "xmin": 0, "ymin": 0, "xmax": 10, "ymax": 10},
    ])
    # fmt: on

    fixed_targets_df, fixed_preds_df = fix_miss_error(
        errors_df, targets_df, preds_df
    )

    pd.testing.assert_frame_equal(fixed_preds_df, preds_df)
    pd.testing.assert_frame_equal(
        expected_fixed_targets_df.set_index("target_id"),
        fixed_targets_df.set_index("target_id"),
        check_like=True,
    )
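

# To run (assuming this file is saved as, e.g., test_error_analysis.py next to
# the article's classify_errors.py and error_impact.py):
#     pytest test_error_analysis.py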