Forked from julienr/sklearn_classif_report_to_latex.py
Created August 11, 2017 08:12
Save leezu/93906686e678a790e6c6d46a484c6171 to your computer and use it in GitHub Desktop.
Parse a scikit-learn classification_report and convert it to a LaTeX table
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Code to parse sklearn classification_report | |
""" | |
## | |
import sys | |
import collections | |
## | |
def parse_classification_report(clfreport):
    """
    Parse a sklearn classification report into a dict keyed by class name
    and containing a tuple (precision, recall, fscore, support) for each class.

    The expected layout is: a header line, one line per class, and a final
    line whose first token is 'avg' (e.g. "avg / total"). Blank lines are
    ignored. The final line is stored under the key 'avg'; all entries keep
    the order in which they appear in the report. Reports whose last line
    does not start with 'avg' (such as the "macro avg" / "weighted avg"
    summaries printed by newer scikit-learn releases) are rejected.

    :param clfreport: text produced by ``classification_report``
    :returns: ``collections.OrderedDict`` mapping class name (and 'avg') to a
        ``(precision, recall, fscore, support)`` tuple of
        ``(float, float, float, int)``
    :raises ValueError: if the text does not match the expected layout
    """
    # Remove empty lines. Build a real list: on Python 3 ``filter`` returns a
    # lazy iterator, which cannot be indexed as done below.
    lines = [line for line in clfreport.splitlines() if line.strip()]
    if len(lines) < 2:
        raise ValueError('classification report needs at least a header '
                         'line and an average line')
    # Starts with a header, then score for each class and finally an average
    header = lines[0]
    cls_lines = lines[1:-1]
    avg_line = lines[-1]
    # Explicit exceptions instead of ``assert``: asserts vanish under -O.
    if header.split() != ['precision', 'recall', 'f1-score', 'support']:
        raise ValueError('unexpected report header: %r' % header)
    if avg_line.split()[0] != 'avg':
        raise ValueError('last report line is not an average line: %r'
                         % avg_line)
    # We cannot simply use split because class names can have spaces. So
    # instead figure the width of the class field by looking at the
    # indentation of the precision header.
    cls_field_width = len(header) - len(header.lstrip())

    def parse_line(line):
        """Parse one report line into (name, precision, recall, fscore, support)."""
        cls_name = line[:cls_field_width].strip()
        fields = line[cls_field_width:].split()
        if len(fields) != 4:
            raise ValueError('expected 4 score columns, got %d in line: %r'
                             % (len(fields), line))
        precision, recall, fscore, support = fields
        return (cls_name, float(precision), float(recall), float(fscore),
                int(support))

    # Now, collect all the class names and scores in an ordered dict
    data = collections.OrderedDict()
    for line in cls_lines:
        parsed = parse_line(line)
        data[parsed[0]] = parsed[1:]
    # The summary line is stored under the fixed key 'avg', not its label.
    data['avg'] = parse_line(avg_line)[1:]
    return data
#parse_classification_report(clfreport) | |
## | |
# Characters with special meaning in LaTeX text mode, mapped to escaped forms.
_LATEX_SPECIAL_CHARS = {
    '\\': '\\textbackslash{}',
    '&': '\\&',
    '%': '\\%',
    '$': '\\$',
    '#': '\\#',
    '_': '\\_',
    '{': '\\{',
    '}': '\\}',
    '~': '\\textasciitilde{}',
    '^': '\\textasciicircum{}',
}


def _latex_escape(text):
    """Return *text* with LaTeX special characters escaped in a single pass."""
    # Single pass so the braces emitted for '\\' are not escaped again.
    return ''.join(_LATEX_SPECIAL_CHARS.get(ch, ch) for ch in text)


def report_to_latex_table(data, escape=True):
    """
    Render parsed classification report data as a LaTeX ``tabular``.

    :param data: mapping of class name -> (precision, recall, fscore, support),
        e.g. the result of ``parse_classification_report``; rows are emitted
        in the mapping's iteration order
    :param escape: escape LaTeX special characters (``_``, ``&``, ``%``, ...)
        in class names; pass False if the names are already valid LaTeX
    :returns: the table source as a string, without a trailing newline
    """
    parts = [
        '\\begin{tabular}{c | c c c c}\n',
        'Class & Precision & Recall & F-score & Support\\\\\n',
        '\\hline\n',
        # Second rule followed by a row break (\\), leaving an empty spacer
        # row before the data rows.
        '\\hline\\\\\n',
    ]
    # .items() works on Python 2 and 3 (iteritems is Python 2 only).
    for cls, scores in data.items():
        name = _latex_escape(cls) if escape else cls
        parts.append(name + ' & ' + ' & '.join(str(s) for s in scores)
                     + '\\\\\n')
    parts.append('\\end{tabular}')
    return ''.join(parts)
#print report_to_latex_table(data) | |
## | |
if __name__ == '__main__':
    # Command-line entry point: read a saved report file, print the table.
    if len(sys.argv) < 2:
        sys.exit('usage: %s REPORT_FILE' % sys.argv[0])
    with open(sys.argv[1]) as f:
        data = parse_classification_report(f.read())
    # Function-call form of print works on both Python 2 and Python 3.
    print(report_to_latex_table(data))
## |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.