Skip to content

Instantly share code, notes, and snippets.

@rbpatt2019
Created March 30, 2020 15:36
Show Gist options
Save rbpatt2019/7a5a2c2409d34b8535f5150f202669b4 to your computer and use it in GitHub Desktop.
A convenience function for plotting heatmaps with matplotlib and seaborn
import matplotlib.pyplot as plt
import seaborn as sns
def plot_cm(cm,
            xlabs,
            ylabs,
            data_range=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            title="Confusion Matrix",
            xlabel="Predicted Labels",
            ylabel="Actual Labels",
            savefig=False,
            path="cm.png",
            cmap="jet"):
    """Plot a 2D heatmap.

    I mostly use this for confusion matrices, but it generalises to any
    2D data that can be represented as a heatmap.

    Side effects: calls ``sns.set()``, which changes the global
    matplotlib/seaborn style for the rest of the session, and always
    draws on a newly created 12x12 inch figure.

    :PARAM: cm: 2-d array containing the confusion matrix,
        preferably normalised (e.g. ``normalize="true"``).
        See sklearn.metrics.confusion_matrix
    :PARAM: xlabs: list-like of x-axis tick labels
    :PARAM: ylabs: list-like of y-axis tick labels
    :PARAM: data_range: non-empty iterable of colour-bar tick values.
        The first value is used as the colour-scale minimum and the
        last value as the maximum.
        Default: (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
    :PARAM: title: str, figure title
    :PARAM: xlabel: str, x-axis label
    :PARAM: ylabel: str, y-axis label
    :PARAM: savefig: bool, whether or not to save the figure
        Default: False
    :PARAM: path: str, where to save the image
        Only used if savefig
    :PARAM: cmap: colormap name or object passed to seaborn
        Default: "jet"
    :RETURNS: ax: mpl axes object containing the plot
    :RAISES: ValueError: if data_range is empty
    """
    # Materialise once so tuples, lists, arrays and iterators all work, and
    # validate before touching any global plotting state.
    ticks = list(data_range)
    if not ticks:
        raise ValueError("data_range must contain at least one value")

    sns.set()
    fig = plt.figure(figsize=(12, 12))
    ax = sns.heatmap(
        cm,
        cmap=cmap,
        xticklabels=xlabs,
        yticklabels=ylabs,
        vmin=ticks[0],
        vmax=ticks[-1],
        cbar_kws={
            "shrink": 0.5,
            "ticks": ticks,
        },
        square=True,
    )
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.tick_params(
        axis="both",
        labelsize=8,
        direction="out",
        bottom=True,
        left=True,
    )
    if savefig:
        # Save the figure created above explicitly instead of relying on
        # pyplot's implicit "current figure".
        fig.savefig(
            path,
            dpi=300,
            bbox_inches="tight",
            transparent=False,
        )
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment