{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Conditional Random Field.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": [
        "LhGVxuBE87jW"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "metadata": {
        "id": "fq633pym7oLf",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Learning sequences with Conditional Random Fields (CRF)\n",
        "\n",
        "CRFs are sequence-labelling algorithms widely used in natural language processing, among other fields. They can be seen as a sequential version of logistic regression. Thanks to their simplicity, they can learn from less data than recurrent neural networks (RNN) and are more stable to train."
      ]
    },
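    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "For reference, a linear-chain CRF models the probability of a label sequence $y$ given an observation sequence $x$ as\n",
        "\n",
        "$$p(y \\mid x) = \\frac{1}{Z(x)} \\exp\\left( \\sum_{t} \\sum_{k} w_k \\, f_k(y_{t-1}, y_t, x, t) \\right)$$\n",
        "\n",
        "where the $f_k$ are feature functions, the $w_k$ their learned weights, and $Z(x)$ the normalisation over all possible label sequences. The implementation in this notebook can be seen as a simplified, locally normalised variant of this formulation: each step applies a softmax over the feature vector of the current position, which is closer in spirit to a maximum-entropy Markov model than to an exactly normalised CRF."
      ]
    },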
    {
      "metadata": {
        "id": "LhGVxuBE87jW",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Importing the libraries\n",
        "\n",
        " - PyTorch"
      ]
    },
    {
      "metadata": {
        "id": "WqkBN_Ha6x3t",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch.nn import Sequential, Linear, Softmax\n",
        "from torch.nn.functional import cross_entropy\n",
        "from torch.optim import SGD"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "vAuvblDt9DSB",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Model implementation\n",
        "\n",
        "The model does not include the Viterbi algorithm, which would find the globally optimal label sequence, because this implementation is meant to run in real time: labels are predicted greedily, one position at a time."
      ]
    },
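    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Concretely, the greedy decoding used in the test cells below can be written as picking, at each position $t$, the label that maximises the sequence score given the labels already predicted (the notation here is ours, not part of the original code):\n",
        "\n",
        "$$\\hat{y}_t = \\arg\\max_{y \\in \\{0, \\dots, K-1\\}} \\mathrm{score}(x_{1:t},\\; \\hat{y}_{1:t-1},\\; y)$$\n",
        "\n",
        "where $K$ is `n_labels` and the score is the quantity returned by `get_score`. This avoids Viterbi's dynamic programming over whole label sequences, but it can commit to locally attractive yet globally suboptimal labels."
      ]
    },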
    {
      "metadata": {
        "id": "WsTpuEZcxhtO",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "class CRF:\n",
        "    def __init__(self, n_labels, features_fn, previous_feature_max_shift=1):\n",
        "\n",
        "        # Generate the features that look at the predictions made at the previous steps.\n",
        "        for shift in range(1, previous_feature_max_shift + 1):\n",
        "            for label in range(n_labels):\n",
        "                features_fn.append(\n",
        "                    lambda x_seq, y_seq, pos, x, is_y=label, shift=shift: self._previous_is_label_feature(x_seq, y_seq, pos, x, is_y, shift)\n",
        "                )\n",
        "\n",
        "        self.n_labels = n_labels\n",
        "        self.n_features = len(features_fn)\n",
        "        self.features_fn = features_fn\n",
        "        self.linear_model = Sequential(Linear(self.n_features, self.n_labels),\n",
        "                                       Softmax(dim=1))  # The model is a multiclass logistic regression.\n",
        "\n",
        "    @staticmethod\n",
        "    def _previous_is_label_feature(x_seq, y_seq, pos, x, is_y, shift):\n",
        "        # A negative index would silently wrap around to the end of the list,\n",
        "        # so positions before the start of the sequence count as \"no match\".\n",
        "        if pos - shift < 0: return 0\n",
        "        try: return int(y_seq[pos - shift] == is_y)\n",
        "        except IndexError: return 0\n",
        "\n",
        "    def _featurize_state(self, x_seq, y_seq, pos, x):\n",
        "        \"\"\"Run the feature functions to extract the features (and gather them into a vector).\"\"\"\n",
        "        return torch.tensor([fn(x_seq, y_seq, pos, x) for fn in self.features_fn]).float()\n",
        "\n",
        "    def _featurize_sequence(self, x_seq, y_seq):\n",
        "        \"\"\"Build a matrix containing the features at every step.\"\"\"\n",
        "        featurized = [self._featurize_state(x_seq, y_seq, pos, x_state) for pos, x_state in enumerate(x_seq)]\n",
        "        return torch.stack(featurized)\n",
        "\n",
        "    def get_score(self, x_seq, y_seq):\n",
        "        \"\"\"Return the validity score of the label sequence `y_seq` given `x_seq`.\"\"\"\n",
        "        featurized = self._featurize_sequence(x_seq, y_seq)\n",
        "\n",
        "        # Build a \"mask\" indicating, for each step, which class that step belongs to.\n",
        "        pred_filter = torch.zeros(len(y_seq), self.n_labels)\n",
        "\n",
        "        for y_pos, y in enumerate(y_seq):\n",
        "            pred_filter[y_pos, y] = 1\n",
        "\n",
        "        return torch.sum(self.linear_model(featurized) * pred_filter)\n",
        "\n",
        "    def predict_next_label(self, x_seq, y_seq, x):\n",
        "        \"\"\"\n",
        "        Predict the continuation of the sequence `y_seq` for a new value `x`.\n",
        "        Return, for each class, the score that the new value `y` belongs to it.\n",
        "        \"\"\"\n",
        "        x_seq, y_seq = x_seq.copy(), y_seq.copy()\n",
        "\n",
        "        x_seq.append(x)\n",
        "        y_seq.append(0)\n",
        "\n",
        "        scores = torch.empty(self.n_labels)\n",
        "\n",
        "        for label in range(self.n_labels):\n",
        "            y_seq[-1] = label\n",
        "            scores[label] = self.get_score(x_seq, y_seq)\n",
        "\n",
        "        return scores\n",
        "\n",
        "    def train(self, x_seqs, y_seqs, n_epoch=500, lr=0.1):\n",
        "        \"\"\"Train the model.\"\"\"\n",
        "\n",
        "        optimizer = SGD(self.linear_model.parameters(), lr)\n",
        "\n",
        "        for epoch in range(n_epoch):\n",
        "            for x_seq, y_seq in zip(x_seqs, y_seqs):\n",
        "                for i, x in enumerate(x_seq):\n",
        "                    optimizer.zero_grad()\n",
        "\n",
        "                    label_scores = self.predict_next_label(x_seq[:i], y_seq[:i], x)\n",
        "\n",
        "                    # Compare the per-class scores for the next value with the true class.\n",
        "                    loss = cross_entropy(label_scores.unsqueeze(0),\n",
        "                                         torch.tensor(y_seq[i]).unsqueeze(0))\n",
        "\n",
        "                    loss.backward()\n",
        "                    optimizer.step()\n",
        "            if epoch % 100 == 0:\n",
        "                print(\"Epoch\", epoch, \"Loss :\", float(loss))"
      ],
      "execution_count": 0,
      "outputs": []
    },
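    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The two test cells below repeat the same greedy decoding loop. As a small convenience (this helper is not part of the original notebook, and `greedy_decode` is a name introduced here), it can be factored out as follows; it relies only on `predict_next_label` defined above, and `greedy_decode(model, ['a', 'b', 'c'] * 20)` would reproduce the loop of the first test."
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def greedy_decode(crf, x_seq, first_label=0):\n",
        "    \"\"\"Greedily label `x_seq` with `crf`, one position at a time.\n",
        "\n",
        "    The first label is fixed (the test cells below assume label 0 for the first\n",
        "    token); every following label is the argmax of the scores returned by\n",
        "    `predict_next_label`.\n",
        "    \"\"\"\n",
        "    y_seq = [first_label]\n",
        "    for i in range(1, len(x_seq)):\n",
        "        scores = crf.predict_next_label(x_seq[:i], y_seq, x_seq[i]).tolist()\n",
        "        y_seq.append(scores.index(max(scores)))\n",
        "    return y_seq"
      ],
      "execution_count": 0,
      "outputs": []
    },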
    {
      "metadata": {
        "id": "4O61vYf89iIt",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Test 1: Learning a simple sequence\n",
        "\n",
        "The model has to learn the sequence 012012012012012... with no information other than the previous digits (its own previous predictions)."
      ]
    },
    {
      "metadata": {
        "id": "bNBBUm0l68aw",
        "colab_type": "code",
        "outputId": "59693177-3c99-4568-e74e-0136b4377c4b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102
        }
      },
      "cell_type": "code",
      "source": [
        "model = CRF(3, [])\n",
        "model.train(\n",
        "    [['a', 'b', 'c']*10],\n",
        "    [[0,1,2]*10],\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 0 Loss : 0.8394050598144531\n",
            "Epoch 100 Loss : 0.2642688751220703\n",
            "Epoch 200 Loss : 0.2548484802246094\n",
            "Epoch 300 Loss : 0.2513141632080078\n",
            "Epoch 400 Loss : 0.2494029998779297\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "9O967Jwl6-1D",
        "colab_type": "code",
        "outputId": "7498106f-7736-4d95-d590-89428b493485",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        }
      },
      "cell_type": "code",
      "source": [
        "x_seq = ['a', 'b', 'c']*20\n",
        "y_seq = [0]\n",
        "\n",
        "for i in range(1, len(x_seq)):\n",
        "    scores = model.predict_next_label(x_seq[:i], y_seq, x_seq[i]).tolist()\n",
        "    pred = scores.index(max(scores))\n",
        "    y_seq.append(pred)\n",
        "\n",
        "print(y_seq)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "Q3XgItlb_87O",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The sequence is valid: the test shows that the CRF is able to memorise a simple sequence."
      ]
    },
    {
      "metadata": {
        "id": "NUto6XH_AP3H",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Test 2: Extracting an address\n",
        "\n",
        "The model has to perform a simple named-entity recognition (NER) task: extracting an address from a sentence (in the context of a delivery request).\n",
        "\n",
        "The classification uses 4 classes:\n",
        "\n",
        "0. The word is not part of the address\n",
        "1. The word is the street number\n",
        "2. The word is part of the street name\n",
        "3. The word is part of the city name\n",
        "\n",
        "The model has several *features* at its disposal:\n",
        "\n",
        " - Whether the word starts with a capital letter\n",
        " - Whether the word is a number\n",
        " - Whether the word belongs to a list of common street-type words\n",
        " - The predictions made for the 5 previous words\n",
        "\n"
      ]
    },
    {
      "metadata": {
        "id": "nedslX6FANeu",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "model = CRF(4,\n",
        "            [\n",
        "                lambda x_seq, y_seq, pos, x: int(x.istitle()),\n",
        "                lambda x_seq, y_seq, pos, x: int(x.isdigit()),\n",
        "                lambda x_seq, y_seq, pos, x: int(x.lower() in ['rue', 'avenue', 'place', 'boulevard', 'chemin', 'impasse', 'allée', 'route'])\n",
        "            ], 5)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "7ymZa4bkDBQu",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The training set contains 3 examples (generated with [Fake Name Generator](https://fr.fakenamegenerator.com/gen-random-fr-fr.php)):\n",
        "\n",
        "- \"Livrez-moi au 62, Boulevard Amiral Courbet à Saint-Nazaire\"\n",
        "- \"J'habite au 56, place Stanislas, Nantes\"\n",
        "- \"Je vis 39, rue de la République à Lyon\""
      ]
    },
    {
      "metadata": {
        "id": "hODdldTaCzOl",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "dataset = [\n",
        "    (\"Livrez moi au 62 Boulevard Amiral Courbet à Saint Nazaire\".split(), [0, 0, 0, 1, 2, 2, 2, 0, 3, 3]),\n",
        "    (\"J habite au 56 place Stanislas Nantes\".split(), [0, 0, 0, 1, 2, 2, 3]),\n",
        "    (\"Je vis 39 rue de la République à Lyon\".split(), [0, 0, 1, 2, 2, 2, 2, 0, 3])\n",
        "]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "FO7xQtyuFBek",
        "colab_type": "code",
        "outputId": "f0dc76dd-60e9-4a26-835f-e1f859d7f063",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 187
        }
      },
      "cell_type": "code",
      "source": [
        "x, y = zip(*dataset)\n",
        "\n",
        "model.train(x, y, n_epoch=1000)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 0 Loss : 1.409395456314087\n",
            "Epoch 100 Loss : 0.7932796478271484\n",
            "Epoch 200 Loss : 0.632418155670166\n",
            "Epoch 300 Loss : 0.5637807846069336\n",
            "Epoch 400 Loss : 0.5313100814819336\n",
            "Epoch 500 Loss : 0.5110149383544922\n",
            "Epoch 600 Loss : 0.4961566925048828\n",
            "Epoch 700 Loss : 0.48514556884765625\n",
            "Epoch 800 Loss : 0.4767451286315918\n",
            "Epoch 900 Loss : 0.46979618072509766\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "OJsyzckMFMij",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "To test our model, we will use the following sentence:\n",
        "\n",
        " - \"C'est au 24 chemin des Bateliers à Clermont Ferrand\""
      ]
    },
    {
      "metadata": {
        "id": "1efTHNFHFGny",
        "colab_type": "code",
        "outputId": "0c312384-c7fe-4a49-edd6-6e21c74795d7",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "cell_type": "code",
      "source": [
        "x_seq = \"C'est au 24 chemin des Bateliers à Clermont Ferrand\".split()\n",
        "y_seq = [0]\n",
        "\n",
        "for i in range(1, len(x_seq)):\n",
        "    scores = model.predict_next_label(x_seq[:i], y_seq, x_seq[i]).tolist()\n",
        "    pred = scores.index(max(scores))\n",
        "    y_seq.append(pred)\n",
        "\n",
        "print(y_seq)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[0, 0, 1, 2, 2, 2, 0, 3, 3]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "-ktalZBnHEBD",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The entity extraction is correct, but to be more robust our model would need more training examples and more *features*; a few possible additional features are sketched below."
      ]
    },
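    {
      "metadata": {
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "As an illustration of what such additional features could look like (these are suggestions rather than part of the original notebook, and the word lists and thresholds are only examples), they simply need to follow the same `(x_seq, y_seq, pos, x)` signature expected by `CRF`:"
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Possible extra feature functions (illustrative only; the word lists are examples).\n",
        "# Each one follows the (x_seq, y_seq, pos, x) signature expected by CRF.\n",
        "extra_features = [\n",
        "    # The word is fully uppercase (e.g. an acronym).\n",
        "    lambda x_seq, y_seq, pos, x: int(x.isupper()),\n",
        "    # The word is very short (stop words such as \"de\", \"la\", \"à\" rarely start an address span).\n",
        "    lambda x_seq, y_seq, pos, x: int(len(x) <= 2),\n",
        "    # The previous word is a street-type word, a strong hint that we are inside a street name.\n",
        "    lambda x_seq, y_seq, pos, x: int(pos > 0 and x_seq[pos - 1].lower() in ['rue', 'avenue', 'place', 'boulevard', 'chemin', 'impasse', 'allée', 'route']),\n",
        "    # The word looks like a French postal code (exactly five digits).\n",
        "    lambda x_seq, y_seq, pos, x: int(x.isdigit() and len(x) == 5),\n",
        "]\n",
        "\n",
        "# These could then be passed to the constructor alongside the original features, e.g.\n",
        "# model = CRF(4, base_features + extra_features, 5)  # `base_features` is a hypothetical list of the original three features"
      ],
      "execution_count": 0,
      "outputs": []
    },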
    {
      "metadata": {
        "id": "bJc8xCBTMo6n",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}