Skip to content

Instantly share code, notes, and snippets.

@devforfu
Created January 14, 2025 11:03
Show Gist options
  • Save devforfu/20e0ce4a459a3bf29da13f0fae606038 to your computer and use it in GitHub Desktop.
Save devforfu/20e0ce4a459a3bf29da13f0fae606038 to your computer and use it in GitHub Desktop.
auc_class_imbalance_sketch.py
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt
def make_scores(n_neg=900, n_pos=100, *, mean_neg=0.3, mean_pos=0.6, std=0.1, seed=42):
    """Simulate classifier scores for an imbalanced binary problem.

    Negative-class scores are drawn from N(mean_neg, std) and positive-class
    scores from N(mean_pos, std), using a local generator seeded with `seed`
    so that global NumPy RNG state is left untouched.

    Args:
        n_neg: Number of negative (label 0) samples.
        n_pos: Number of positive (label 1) samples.
        mean_neg: Mean score of the negative class.
        mean_pos: Mean score of the positive class.
        std: Standard deviation shared by both classes.
        seed: Seed for `np.random.default_rng`.

    Returns:
        (y, scores): float label array (all 0.0 for the negatives, followed by
        all 1.0 for the positives) and the matching score array, both of
        length n_neg + n_pos. Higher scores mean "more positive".
    """
    rng = np.random.default_rng(seed)
    scores_neg = rng.normal(mean_neg, std, n_neg)
    scores_pos = rng.normal(mean_pos, std, n_pos)
    y = np.concatenate([np.zeros(n_neg), np.ones(n_pos)])
    scores = np.concatenate([scores_neg, scores_pos])
    return y, scores


def compute_aucs(y, scores):
    """Return (overall, negative-class, positive-class) ROC AUC.

    `scores` must be oriented towards class 1 (higher = more likely class 1).
    ROC AUC is the probability that a randomly chosen member of the target
    class outranks a randomly chosen non-member. It depends only on ranking,
    so it does not change with the class ratio.
    """
    auc_overall = roc_auc_score(y, scores)
    # To score class 0 as the target, the score must grow with class-0
    # likelihood. Reusing `scores` unchanged would just report
    # 1 - auc_overall. Negation flips the ranking without assuming that
    # scores lie in [0, 1], which normal draws need not.
    auc_neg = roc_auc_score((y == 0).astype(int), -scores)
    # Same computation as auc_overall. Kept for a side-by-side printout.
    auc_pos = roc_auc_score((y == 1).astype(int), scores)
    return auc_overall, auc_neg, auc_pos


def plot_roc(y, scores, auc_overall):
    """Plot the ROC curve of `scores` against labels `y` and show it."""
    fpr, tpr, _ = roc_curve(y, scores)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f"Overall AUC = {auc_overall:.2f}")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.title("ROC Curve")
    plt.legend()
    plt.grid()
    plt.show()


def main():
    """Simulate 900 negatives / 100 positives, plot the ROC curve and print AUCs."""
    y, scores = make_scores()
    auc_overall, auc_neg, auc_pos = compute_aucs(y, scores)
    plot_roc(y, scores, auc_overall)
    print(f"Overall AUC: {auc_overall:.2f}")
    print(f"Negative Class AUC: {auc_neg:.2f}")
    print(f"Positive Class AUC: {auc_pos:.2f}")


# Guard so that importing this module does not run the demo or open a plot window.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment