Last active
May 11, 2020 14:40
-
-
Save bbennett36/f499b208135463735fe5841eb631ab4c to your computer and use it in GitHub Desktop.
pred density
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
from scipy.stats import gaussian_kde | |
def plot_prediction_density(target, probs, figsize=(8, 5),
                            title='Prediction Density Plot'):
    """Plot one Gaussian kernel density curve of model scores per target class.

    For every distinct value in ``target`` the scores belonging to that class
    are smoothed with ``scipy.stats.gaussian_kde`` (Scott bandwidth) and drawn
    as a filled curve over 1000 evenly spaced points on [0, 1].

    Args:
        target: 1-D array-like of class labels, one per sample (list, NumPy
            array, pandas Series, ...).
        probs: 1-D array-like of numeric model scores, matched to ``target``
            by position. The density is only evaluated on [0, 1], so scores
            are presumably probabilities; values outside that range are not
            shown.
        figsize: Figure size forwarded to ``plt.subplots``.
        title: Title placed above the axes.

    Returns:
        The matplotlib ``Figure`` containing the plot.

    Raises:
        ValueError: If ``target`` and ``probs`` do not have the same number
            of elements.
    """
    import warnings

    # Work on plain NumPy arrays so the boolean mask below is element-wise and
    # positional; a Python list would make ``target == value`` a single bool,
    # and pandas objects would align on index labels instead of position.
    labels = np.asarray(target).ravel()
    scores = np.asarray(probs, dtype=float).ravel()
    if labels.shape != scores.shape:
        raise ValueError(
            'target and probs must have the same length, got {} and {}'.format(
                labels.size, scores.size))

    x_grid = np.linspace(0, 1, 1000)
    fig, ax = plt.subplots(figsize=figsize)

    # np.unique returns the distinct labels already sorted.
    for value in np.unique(labels):
        class_scores = scores[labels == value]
        # gaussian_kde cannot be fitted to fewer than two points or to data
        # with zero spread (singular covariance); skip such a class with a
        # warning instead of aborting the whole plot.
        if class_scores.size < 2 or class_scores.min() == class_scores.max():
            warnings.warn(
                'Skipping target = {}: need at least two distinct scores to '
                'estimate a density'.format(value))
            continue
        density = gaussian_kde(class_scores, bw_method='scott').evaluate(x_grid)
        (line,) = ax.plot(x_grid, density, linewidth=2.5,
                          label='Target = {}'.format(value))
        # Tie the fill colour to the line explicitly rather than relying on
        # two separate colour cycles staying in step.
        ax.fill_between(x_grid, density, facecolor=line.get_color(), alpha=0.6)

    # Address the axes directly instead of pyplot's implicit "current axes".
    ax.set_title(title)
    ax.set_xlabel('Model Score')
    ax.set_ylabel('Kernel Density')
    ax.legend()

    # Close the figure in pyplot before handing it back, as the original did;
    # the Figure object itself is still returned to the caller.
    plt.close(fig)
    return fig
# Example usage: `target` will usually be y_test and `probs` will be your
# predicted scores. Both names must already be defined by the caller.
plot_prediction_density(y_test, scores)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment