"""
Code to parse a sklearn classification_report.

Original: https://gist.github.com/julienr/6b9b9a03bd8224db7b4f

Modified to work with Python 3 and with classification report averages.
"""
|
|
|
import sys |
|
import collections |
|
|
|
def parse_classification_report(clfreport):
    """Parse a sklearn classification report into an ordered dict.

    The text must start with the ``precision recall f1-score support``
    header, followed by one row per class and optional summary rows. The
    last non-blank line must be an average row whose second word is ``avg``
    (for example ``weighted avg``).

    Args:
        clfreport: The report as one string with its indentation intact.
            The indentation of the header is used to measure the width of
            the class-name column, because class names may contain spaces.

    Returns:
        A ``collections.OrderedDict`` mapping each row name, in report
        order, to a ``(precision, recall, fscore, support)`` tuple. The
        final average row is stored under the key ``'avg'``. An
        ``accuracy`` summary row that carries only an f1-score and a
        support value is skipped, since it cannot fill the 4-tuple.

    Raises:
        IndexError: If the report contains no non-blank lines.
        AssertionError: If the header or the final average row does not
            have the expected shape.
        ValueError: If a row does not hold exactly four score fields or a
            field cannot be converted to a number.
    """
    # Drop blank lines but keep the survivors unstripped: the leading
    # whitespace of the header is needed to locate the score columns.
    lines = [line for line in clfreport.split('\n') if line.strip()]

    # Starts with a header, then one row per class / summary, then the
    # final average row.
    header = lines[0]
    body_lines = lines[1:-1]
    avg_line = lines[-1]

    assert header.split() == ['precision', 'recall', 'f1-score', 'support'], \
        "unexpected classification report header: {!r}".format(header)
    assert avg_line.split()[1] == 'avg', \
        "last line is not an average row: {!r}".format(avg_line)

    # We cannot simply split each row on whitespace because class names can
    # contain spaces. The class column is as wide as the header's indent.
    cls_field_width = len(header) - len(header.lstrip())

    def split_row(line):
        """Return (row name, list of raw score fields) for one report row."""
        return line[:cls_field_width].strip(), line[cls_field_width:].split()

    def to_scores(line, fields):
        """Convert the four raw fields of a row into typed scores."""
        if len(fields) != 4:
            raise ValueError(
                "expected 4 score fields but got {} in line: {!r}".format(
                    len(fields), line))
        precision, recall, fscore, support = fields
        return (float(precision), float(recall), float(fscore), int(support))

    data = collections.OrderedDict()
    for line in body_lines:
        cls_name, fields = split_row(line)
        # The 'accuracy' summary row leaves the precision and recall
        # columns empty, so only two fields remain; skip it instead of
        # failing on the 4-way unpack.
        if cls_name == 'accuracy' and len(fields) == 2:
            continue
        data[cls_name] = to_scores(line, fields)

    # The final average row is always exposed under the fixed key 'avg'.
    data['avg'] = to_scores(avg_line, split_row(avg_line)[1])

    return data
|
|
|
def report_to_latex_table(data):
    """Render parsed classification scores as a LaTeX ``table`` environment.

    Args:
        data: Mapping of row name to an iterable of scores. It is walked
            with ``.items()`` in iteration order, and every score is
            converted with ``str()``. Row names are inserted verbatim; no
            LaTeX escaping is applied.

    Returns:
        The LaTeX source as a single string with no trailing newline. One
        table row is emitted per entry of ``data``, and a ``\\midrule`` is
        inserted before every row whose name contains ``'micro'``.
    """
    # Collect the pieces and join once instead of repeated string +=.
    parts = [
        "\\begin{table}\n",
        "\\caption{Latex Table from Classification Report}\n",
        "\\label{table:classification:report}\n",
        "\\centering\n",
        "\\begin{tabular}{c | c c c r}\n",
        "Class & Precision & Recall & F-score & Support\\\\\n",
        # Written with an escaped backslash: "\m" is not a valid escape.
        "\\midrule\n",
    ]
    for cls, scores in data.items():
        # Separate the per-class rows from the averages that follow.
        if 'micro' in cls:
            parts.append("\\midrule\n")
        parts.append(cls + " & " + " & ".join(str(s) for s in scores))
        parts.append("\\\\\n")
    parts.append("\\end{tabular}\n")
    parts.append("\\end{table}")
    return "".join(parts)
|
|
|
if __name__ == '__main__':
    # Script entry point: read the classification report from the file
    # named by the first command-line argument and print it as LaTeX.
    if len(sys.argv) < 2:
        # Exit with a usage hint instead of an IndexError traceback.
        sys.exit("usage: {} REPORT_FILE".format(sys.argv[0]))
    with open(sys.argv[1]) as f:
        data = parse_classification_report(f.read())
    print(report_to_latex_table(data))
# NOTE: The classification report can contain an extra 'accuracy' row that has
# no values in the precision and recall columns. Parsing that row fails with
# "not enough values to unpack (expected 4, got 2)". A workaround is to drop
# the row right after the empty lines are removed in
# parse_classification_report():
#     lines = [s for s in lines if 'accuracy' not in s]