import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def print_confusion_matrix(confusion_matrix, class_names, figsize=(10, 7), fontsize=14):
    """Prints a confusion matrix, as returned by sklearn.metrics.confusion_matrix, as a heatmap.
    Note that because this function returns the created figure object, calling it in a
    notebook will display the figure twice. To prevent this, either append ; to your
    function call, or modify the function by commenting out the return expression.