Created
September 8, 2019 21:46
-
-
Save adiprasad/6305306e9ea55732e38c038880742b4c to your computer and use it in GitHub Desktop.
CRF model training and prediction on Moby Dick tokens
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Imports" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pycrfsuite\n", | |
"from collections import defaultdict\n", | |
"import os\n", | |
"from sklearn.preprocessing import LabelBinarizer\n", | |
"from sklearn.metrics import accuracy_score\n", | |
"from itertools import chain" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Helper functions for converting data into a suitable input format for CRFsuite" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def return_file_path(dir_path, file_path):\n", | |
"\t\"\"\"Join a directory and a file name into a single path with os.path.join.\"\"\"\n", | |
"\tjoined_path = os.path.join(dir_path, file_path)\n", | |
"\treturn joined_path\n", | |
"\n", | |
"\n", | |
"def convert_x_i(x_i):\n", | |
"\t\"\"\"Turn one numeric feature vector into a feature dict for CRFsuite.\n", | |
"\n", | |
"\tArgs:\n", | |
"\t\tx_i: iterable of values; callers in this notebook pass one parsed line\n", | |
"\t\t\tof a feature file as a list of floats.\n", | |
"\n", | |
"\tReturns:\n", | |
"\t\tdict holding a constant 'bias' entry of 1.0 plus one 'pixel_N' entry per\n", | |
"\t\telement of x_i, where N is the element's position.\n", | |
"\t\"\"\"\n", | |
"\t# A plain dict is sufficient: nothing relies on defaults for missing keys.\n", | |
"\tfeatures = {'bias': 1.0}\n", | |
"\n", | |
"\tfor idx, value in enumerate(x_i):\n", | |
"\t\tfeatures[\"pixel_\" + str(idx)] = value\n", | |
"\n", | |
"\treturn features\n", | |
"\n", | |
"\n", | |
"def convert_x(file_path):\n", | |
"\t\"\"\"Read a whitespace-separated numeric file into a list of feature dicts.\n", | |
"\n", | |
"\tEvery line is split on whitespace, each token is parsed with float(), and the\n", | |
"\tresulting vector is passed through convert_x_i. One entry is produced per line,\n", | |
"\tin file order.\n", | |
"\t\"\"\"\n", | |
"\twith open(file_path, \"r\") as handle:\n", | |
"\t\t# split() with no argument already ignores leading/trailing whitespace.\n", | |
"\t\treturn [\n", | |
"\t\t\tconvert_x_i([float(token) for token in row.split()])\n", | |
"\t\t\tfor row in handle\n", | |
"\t\t]\n", | |
"\n", | |
"\n", | |
"def prepare_data(data_dir, mode = \"train\"):\n", | |
"\t\"\"\"Load feature sequences and label sequences for one data split.\n", | |
"\n", | |
"\tReads '{data_dir}/{mode}_words.txt', where each non-blank line holds an id and\n", | |
"\ta word separated by whitespace. For every such line the features are read from\n", | |
"\t'{data_dir}/{mode}_words/img_{id}.txt' via convert_x, and the labels are the\n", | |
"\tindividual characters of the word.\n", | |
"\n", | |
"\tArgs:\n", | |
"\t\tdata_dir: directory containing the words file and the feature-file folder.\n", | |
"\t\tmode: split name used to build the file names (default \"train\").\n", | |
"\n", | |
"\tReturns:\n", | |
"\t\t(X, Y): X is a list with one list of feature dicts per word, Y the matching\n", | |
"\t\tlist with one list of characters per word.\n", | |
"\n", | |
"\tRaises:\n", | |
"\t\tValueError: if a non-blank line does not split into exactly two fields.\n", | |
"\t\"\"\"\n", | |
"\tfile_dir = return_file_path(data_dir, \"{}_words\".format(mode))\n", | |
"\twords_file = return_file_path(data_dir, \"{}_words.txt\".format(mode))\n", | |
"\n", | |
"\tX = []\n", | |
"\tY = []\n", | |
"\n", | |
"\twith open(words_file) as f:\n", | |
"\t\tfor line in f:\n", | |
"\t\t\tline = line.strip()\n", | |
"\t\t\tif not line:\n", | |
"\t\t\t\t# Blank lines (e.g. an empty line at the end of the file) carry no\n", | |
"\t\t\t\t# sample and would break the two-field unpacking below.\n", | |
"\t\t\t\tcontinue\n", | |
"\n", | |
"\t\t\ti, word = line.split()\n", | |
"\n", | |
"\t\t\tx_i_file_path = return_file_path(file_dir, \"img_{}.txt\".format(i))\n", | |
"\t\t\tx_i_arr = convert_x(x_i_file_path)\n", | |
"\n", | |
"\t\t\t# One label per character of the word.\n", | |
"\t\t\ty_i_arr = list(word)\n", | |
"\n", | |
"\t\t\tX.append(x_i_arr)\n", | |
"\t\t\tY.append(y_i_arr)\n", | |
"\n", | |
"\treturn X, Y" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Functions to train and test the model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train_model(X, Y, max_iter_count, model_store = \"handwriting-reco.crfsuite\"):\n", | |
"\t\"\"\"Fit a CRF on the given sequences and write the model to model_store.\n", | |
"\n", | |
"\tArgs:\n", | |
"\t\tX: iterable of feature sequences.\n", | |
"\t\tY: iterable of label sequences, paired with X by position.\n", | |
"\t\tmax_iter_count: value passed to the trainer as 'max_iterations'.\n", | |
"\t\tmodel_store: path handed to the trainer's train() call.\n", | |
"\n", | |
"\tPrints the log parser's record of the last training iteration.\n", | |
"\t\"\"\"\n", | |
"\tcrf_trainer = pycrfsuite.Trainer(verbose=False)\n", | |
"\n", | |
"\tfor features, labels in zip(X, Y):\n", | |
"\t\tcrf_trainer.append(features, labels)\n", | |
"\n", | |
"\ttraining_params = {\n", | |
"\t\t'c1': 1.0,  # L1 penalty coefficient\n", | |
"\t\t'c2': 1e-3,  # L2 penalty coefficient\n", | |
"\t\t'max_iterations': max_iter_count,  # upper bound on training iterations\n", | |
"\t\t# also include transitions that are possible but not observed\n", | |
"\t\t'feature.possible_transitions': True,\n", | |
"\t}\n", | |
"\tcrf_trainer.set_params(training_params)\n", | |
"\n", | |
"\tcrf_trainer.train(model_store)\n", | |
"\n", | |
"\tprint(crf_trainer.logparser.last_iteration)\n", | |
"\n", | |
"\n", | |
"def get_preds(X, model_store = \"handwriting-reco.crfsuite\"):\n", | |
"\t\"\"\"Tag every sequence in X with the model stored at model_store.\n", | |
"\n", | |
"\tArgs:\n", | |
"\t\tX: iterable of feature sequences.\n", | |
"\t\tmodel_store: path of the model file to open.\n", | |
"\n", | |
"\tReturns:\n", | |
"\t\tlist with one predicted label sequence per element of X.\n", | |
"\t\"\"\"\n", | |
"\ttagger = pycrfsuite.Tagger()\n", | |
"\ttagger.open(model_store)\n", | |
"\ttry:\n", | |
"\t\tY_pred = [tagger.tag(x) for x in X]\n", | |
"\tfinally:\n", | |
"\t\t# Release the loaded model even if tagging fails part-way through.\n", | |
"\t\ttagger.close()\n", | |
"\n", | |
"\treturn Y_pred\n", | |
"\n", | |
"\n", | |
"def test_model(X_test, Y_test, model_store = \"handwriting-reco.crfsuite\"):\n", | |
"\t\"\"\"Print the label-level accuracy of a stored model on a test set.\n", | |
"\n", | |
"\tArgs:\n", | |
"\t\tX_test: feature sequences to tag.\n", | |
"\t\tY_test: true label sequences, aligned with X_test.\n", | |
"\t\tmodel_store: path of the model file forwarded to get_preds.\n", | |
"\t\"\"\"\n", | |
"\tY_test_pred = get_preds(X_test, model_store)\n", | |
"\n", | |
"\t# Flatten the per-sequence label lists so accuracy is computed per label.\n", | |
"\ty_test_combined = list(chain.from_iterable(Y_test))\n", | |
"\ty_pred_combined = list(chain.from_iterable(Y_test_pred))\n", | |
"\n", | |
"\t# Labels are compared directly; binarizing first is unnecessary and is fragile\n", | |
"\t# when a predicted label never occurs in Y_test.\n", | |
"\tprint(\"Test accuracy : {}\".format(accuracy_score(y_test_combined, y_pred_combined)))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Train the model for 500 iterations" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'loss': 51853.452636, 'error_norm': 57.176499, 'linesearch_trials': 2, 'active_features': 3636, 'num': 500, 'time': 1.73, 'scores': {}, 'linesearch_step': 0.5, 'feature_norm': 85.522574}\n", | |
"Training successful with 500 iterations.. Enable verbose in the CRF model above and re-run to track progress\n" | |
] | |
} | |
], | |
"source": [ | |
"# Directory holding the words files and the per-word feature folders.\n", | |
"data_dir = './data'\n", | |
"X_train, Y_train = prepare_data(data_dir)\n", | |
"train_model(X_train, Y_train, 500)\n", | |
"\n", | |
"# Single-argument print(...) behaves the same under Python 2 and Python 3.\n", | |
"print(\"Training successful with 500 iterations.. Enable verbose in the CRF model above and re-run to track progress\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Test the model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Test accuracy : 0.853043730931\n" | |
] | |
} | |
], | |
"source": [ | |
"# Load the \"test\" split and print the accuracy of the model stored at the default path.\n", | |
"X_test, Y_test = prepare_data(data_dir, mode = \"test\")\n", | |
"test_model(X_test, Y_test)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.16" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment