Abstractive Text Summarization Using PreSumm
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Untitled0.ipynb", | |
"provenance": [], | |
"collapsed_sections": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "aequJRxrNSgT", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Abstractive Text Summarization Using PreSumm\n", | |
"\n", | |
"## Introduction\n", | |
"\n", | |
"The goal of this notebook is to show how to utilize [PreSumm](https://github.com/nlpyang/PreSumm) to create an abstractive text summarization for a sample news article from CNN.\n", | |
"PreSumm was invented by Yang Liu and Mirella Lapata through their publication [Text Summarization with Pretrained Encoders](https://arxiv.org/abs/1908.08345).\n", | |
"\n", | |
"The usage of PreSumm as shown below is based on the [original PreSumm repository](https://github.com/nlpyang/PreSumm) and an example implemented by [Jonathan Fly](https://gist.github.com/JonathanFly/0f9864d8115c8fef49061fd31a302daf)." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "nE6nC87KN5OT", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Data Preparation\n", | |
"\n", | |
"First, the fork from Ming Chen is cloned.\n", | |
"The fork contains inference and a resolves a few issue reported to the original repository from Yang and Lapata.\n", | |
"These issue were mainly dependency-related due to changed packages." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "BaOa98npS6gC", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 119 | |
}, | |
"outputId": "aecefcf1-3836-47c1-c8a6-db007a9e2202" | |
}, | |
"source": [ | |
"!git clone https://github.com/mingchen62/PreSumm.git\n", | |
"%cd PreSumm" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Cloning into 'PreSumm'...\n", | |
"remote: Enumerating objects: 154, done.\u001b[K\n", | |
"remote: Total 154 (delta 0), reused 0 (delta 0), pack-reused 154\u001b[K\n", | |
"Receiving objects: 100% (154/154), 12.97 MiB | 5.83 MiB/s, done.\n", | |
"Resolving deltas: 100% (64/64), done.\n", | |
"/content/PreSumm\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "A-zDKoVMS-bL", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Second, additional dependencies are installed.\n", | |
"Notice, that `torch` and `pytorch_transformers` have pinned versions.\n", | |
"The PreSumm code only works with `torch==1.1.0` and the corresponding `pytorch_transformers==1.1.0` version." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "8Er74RnFTLfc", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 748 | |
}, | |
"outputId": "f079ac6c-4dc3-41d6-befa-559d8eabf30a" | |
}, | |
"source": [ | |
"!pip install torch==1.1.0 pytorch_transformers==1.1.0 tensorboardX pyrouge" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting torch==1.1.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/69/60/f685fb2cfb3088736bafbc9bdbb455327bdc8906b606da9c9a81bae1c81e/torch-1.1.0-cp36-cp36m-manylinux1_x86_64.whl (676.9MB)\n", | |
"\u001b[K |████████████████████████████████| 676.9MB 24kB/s \n", | |
"\u001b[?25hCollecting pytorch_transformers==1.1.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/50/89/ad0d6bb932d0a51793eaabcf1617a36ff530dc9ab9e38f765a35dc293306/pytorch_transformers-1.1.0-py3-none-any.whl (158kB)\n", | |
"\u001b[K |████████████████████████████████| 163kB 42.2MB/s \n", | |
"\u001b[?25hCollecting tensorboardX\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/af/0c/4f41bcd45db376e6fe5c619c01100e9b7531c55791b7244815bac6eac32c/tensorboardX-2.1-py2.py3-none-any.whl (308kB)\n", | |
"\u001b[K |████████████████████████████████| 317kB 42.7MB/s \n", | |
"\u001b[?25hCollecting pyrouge\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/11/85/e522dd6b36880ca19dcf7f262b22365748f56edc6f455e7b6a37d0382c32/pyrouge-0.1.3.tar.gz (60kB)\n", | |
"\u001b[K |████████████████████████████████| 61kB 7.7MB/s \n", | |
"\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch==1.1.0) (1.18.5)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from pytorch_transformers==1.1.0) (2.23.0)\n", | |
"Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (from pytorch_transformers==1.1.0) (2019.12.20)\n", | |
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from pytorch_transformers==1.1.0) (4.41.1)\n", | |
"Collecting sentencepiece\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n", | |
"\u001b[K |████████████████████████████████| 1.1MB 36.2MB/s \n", | |
"\u001b[?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from pytorch_transformers==1.1.0) (1.14.20)\n", | |
"Requirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorboardX) (3.12.2)\n", | |
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from tensorboardX) (1.12.0)\n", | |
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch_transformers==1.1.0) (2.10)\n", | |
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch_transformers==1.1.0) (1.24.3)\n", | |
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch_transformers==1.1.0) (3.0.4)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch_transformers==1.1.0) (2020.6.20)\n", | |
"Requirement already satisfied: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch_transformers==1.1.0) (0.3.3)\n", | |
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch_transformers==1.1.0) (0.10.0)\n", | |
"Requirement already satisfied: botocore<1.18.0,>=1.17.20 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch_transformers==1.1.0) (1.17.20)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.8.0->tensorboardX) (49.1.0)\n", | |
"Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.18.0,>=1.17.20->boto3->pytorch_transformers==1.1.0) (0.15.2)\n", | |
"Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.18.0,>=1.17.20->boto3->pytorch_transformers==1.1.0) (2.8.1)\n", | |
"Building wheels for collected packages: pyrouge\n", | |
" Building wheel for pyrouge (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for pyrouge: filename=pyrouge-0.1.3-cp36-none-any.whl size=191613 sha256=dc7f14075d64d5e2ccac610a68ebb2ca9d767c6ffd7abf26be87b9ac7bd40699\n", | |
" Stored in directory: /root/.cache/pip/wheels/75/d3/0c/e5b04e15b6b87c42e980de3931d2686e14d36e045058983599\n", | |
"Successfully built pyrouge\n", | |
"\u001b[31mERROR: torchvision 0.6.1+cu101 has requirement torch==1.5.1, but you'll have torch 1.1.0 which is incompatible.\u001b[0m\n", | |
"Installing collected packages: torch, sentencepiece, pytorch-transformers, tensorboardX, pyrouge\n", | |
" Found existing installation: torch 1.5.1+cu101\n", | |
" Uninstalling torch-1.5.1+cu101:\n", | |
" Successfully uninstalled torch-1.5.1+cu101\n", | |
"Successfully installed pyrouge-0.1.3 pytorch-transformers-1.1.0 sentencepiece-0.1.91 tensorboardX-2.1 torch-1.1.0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "AdCoFZKMTVGE", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Third, the pre-trained BERT models for (abstractive) text summarization are downloaded, extracted, and moved to their final destination." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WB4NtKBETj5H", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 136 | |
}, | |
"outputId": "00164faf-c577-478e-9124-8e050a9fc29c" | |
}, | |
"source": [ | |
"%cd /content/PreSumm/models\n", | |
"!gdown https://drive.google.com/uc?id=1-IKVCtc4Q-BdZpjXc4s70_fRsWnjtYLr&export=download # CNN/DM Abstractive model_step_148000.pt\n", | |
"!unzip /content/PreSumm/models/bertsumextabs_cnndm_final_model.zip\n", | |
"!mkdir /content/PreSumm/models/CNN_DailyMail_Abstractive\n", | |
"!mv /content/PreSumm/models/model_step_148000.pt /content/PreSumm/models/CNN_DailyMail_Abstractive" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/content/PreSumm/models\n", | |
"Downloading...\n", | |
"From: https://drive.google.com/uc?id=1-IKVCtc4Q-BdZpjXc4s70_fRsWnjtYLr\n", | |
"To: /content/PreSumm/models/bertsumextabs_cnndm_final_model.zip\n", | |
"1.98GB [00:45, 43.6MB/s]\n", | |
"Archive: /content/PreSumm/models/bertsumextabs_cnndm_final_model.zip\n", | |
" inflating: model_step_148000.pt \n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "z9rhvPmLT11W", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"The following code creates two new directories and downloads the CNN article \"\" that is to be summarized into the directory containing the documents to summarize.\n", | |
"The article is originally from [here](https://edition.cnn.com/2020/07/18/football/arsenal-manchester-city-fa-cup-aubameyang/index.html) and was pre-processed.\n", | |
"The pre-processed version, which is used here, can be found on [Pastebin](https://pastebin.com/JumkZCTB)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "82xbV5ltUPhg", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 221 | |
}, | |
"outputId": "858d338b-ff04-440d-b764-44a15cfba5e0" | |
}, | |
"source": [ | |
"!mkdir /content/PreSumm/bert_data_test/\n", | |
"!mkdir /content/PreSumm/bert_data/cnndm\n", | |
"%cd /content/PreSumm/bert_data/cnndm\n", | |
"!wget https://pastebin.com/raw/JumkZCTB" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/content/PreSumm/bert_data/cnndm\n", | |
"--2020-07-20 14:03:20-- https://pastebin.com/raw/JumkZCTB\n", | |
"Resolving pastebin.com (pastebin.com)... 104.23.98.190, 104.23.99.190, 2606:4700:10::6817:63be, ...\n", | |
"Connecting to pastebin.com (pastebin.com)|104.23.98.190|:443... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: unspecified [text/plain]\n", | |
"Saving to: ‘JumkZCTB’\n", | |
"\n", | |
"JumkZCTB [ <=> ] 2.80K --.-KB/s in 0s \n", | |
"\n", | |
"2020-07-20 14:03:20 (44.7 MB/s) - ‘JumkZCTB’ saved [2869]\n", | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "K9kZBD2SWCOT", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Implementing the summarizer\n", | |
"\n", | |
"In order to implement the summarizer, `nltk`'s Punkt tokenizer needs to be downloaded first.\n", | |
"Its purpose it to separate a given text into the different sentences." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "OT0ubuA9VzsE", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 68 | |
}, | |
"outputId": "edb672cc-530f-43e6-af2f-c0878112eaa3" | |
}, | |
"source": [ | |
"import nltk\n", | |
"nltk.download('punkt')" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[nltk_data] Downloading package punkt to /root/nltk_data...\n", | |
"[nltk_data] Unzipping tokenizers/punkt.zip.\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 5 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "UPGmH0WT2_Hl", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Next, the summarizer is implemented.\n", | |
"In essence, its a wrapper for PreSumm, providing a command-line interface to the user.\n", | |
"The arguments, which are defined in the `init_args()` function, correspond to the ones PreSumm accepts.\n", | |
"The values passed through the CLI are directly fed into PreSumm." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "CkmtXUS1WarH", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "673ed549-bd9f-45ca-89b4-210eddbf1173" | |
}, | |
"source": [ | |
"%%writefile /content/PreSumm/src/summarizer.py\n", | |
"#!/usr/bin/env python\n", | |
"\"\"\"\n", | |
" Main training workflow\n", | |
"\"\"\"\n", | |
"from __future__ import division\n", | |
"\n", | |
"import argparse\n", | |
"import glob\n", | |
"import os\n", | |
"\n", | |
"from others.logging import init_logger\n", | |
"from train_abstractive import validate_abs, train_abs, baseline, test_abs, test_text_abs, load_models_abs\n", | |
"from train_extractive import train_ext, validate_ext, test_ext\n", | |
"from prepro import data_builder\n", | |
"\n", | |
"\n", | |
"def str2bool(v):\n", | |
" if v.lower() in ('yes', 'true', 't', 'y', '1'):\n", | |
" return True\n", | |
" elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n", | |
" return False\n", | |
" else:\n", | |
" raise argparse.ArgumentTypeError('Boolean value expected.')\n", | |
"\n", | |
"\n", | |
"\n", | |
"def init_args():\n", | |
" parser = argparse.ArgumentParser()\n", | |
" parser.add_argument(\"-task\", default='abs', type=str, choices=['ext', 'abs'])\n", | |
" parser.add_argument(\"-encoder\", default='bert', type=str, choices=['bert', 'baseline'])\n", | |
" parser.add_argument(\"-mode\", default='test', type=str, choices=['train', 'validate', 'test'])\n", | |
" parser.add_argument(\"-bert_data_path\", default='../../bert_data_new/cnndm')\n", | |
" parser.add_argument(\"-model_path\", default='../../models/')\n", | |
" parser.add_argument(\"-result_path\", default='../../results/cnndm')\n", | |
" parser.add_argument(\"-temp_dir\", default='../../temp')\n", | |
"\n", | |
" parser.add_argument(\"-batch_size\", default=140, type=int)\n", | |
" parser.add_argument(\"-test_batch_size\", default=200, type=int)\n", | |
"\n", | |
" parser.add_argument(\"-max_pos\", default=800, type=int)\n", | |
" parser.add_argument(\"-use_interval\", type=str2bool, nargs='?',const=True,default=True)\n", | |
" parser.add_argument(\"-large\", type=str2bool, nargs='?',const=True,default=False)\n", | |
" parser.add_argument(\"-load_from_extractive\", default='', type=str)\n", | |
"\n", | |
" parser.add_argument(\"-sep_optim\", type=str2bool, nargs='?',const=True,default=True)\n", | |
" parser.add_argument(\"-lr_bert\", default=2e-3, type=float)\n", | |
" parser.add_argument(\"-lr_dec\", default=2e-3, type=float)\n", | |
" parser.add_argument(\"-use_bert_emb\", type=str2bool, nargs='?',const=True,default=False)\n", | |
"\n", | |
" parser.add_argument(\"-share_emb\", type=str2bool, nargs='?', const=True, default=False)\n", | |
" parser.add_argument(\"-finetune_bert\", type=str2bool, nargs='?', const=True, default=True)\n", | |
" parser.add_argument(\"-dec_dropout\", default=0.2, type=float)\n", | |
" parser.add_argument(\"-dec_layers\", default=6, type=int)\n", | |
" parser.add_argument(\"-dec_hidden_size\", default=768, type=int)\n", | |
" parser.add_argument(\"-dec_heads\", default=8, type=int)\n", | |
" parser.add_argument(\"-dec_ff_size\", default=2048, type=int)\n", | |
" parser.add_argument(\"-enc_hidden_size\", default=512, type=int)\n", | |
" parser.add_argument(\"-enc_ff_size\", default=512, type=int)\n", | |
" parser.add_argument(\"-enc_dropout\", default=0.2, type=float)\n", | |
" parser.add_argument(\"-enc_layers\", default=6, type=int)\n", | |
"\n", | |
" # params for EXT\n", | |
" parser.add_argument(\"-ext_dropout\", default=0.2, type=float)\n", | |
" parser.add_argument(\"-ext_layers\", default=2, type=int)\n", | |
" parser.add_argument(\"-ext_hidden_size\", default=768, type=int)\n", | |
" parser.add_argument(\"-ext_heads\", default=8, type=int)\n", | |
" parser.add_argument(\"-ext_ff_size\", default=2048, type=int)\n", | |
"\n", | |
" parser.add_argument(\"-label_smoothing\", default=0.1, type=float)\n", | |
" parser.add_argument(\"-generator_shard_size\", default=32, type=int)\n", | |
" parser.add_argument(\"-alpha\", default=0.6, type=float)\n", | |
" parser.add_argument(\"-beam_size\", default=5, type=int)\n", | |
" parser.add_argument(\"-min_length\", default=15, type=int)\n", | |
" parser.add_argument(\"-max_length\", default=150, type=int)\n", | |
" parser.add_argument(\"-max_tgt_len\", default=140, type=int)\n", | |
"\n", | |
" # params for preprocessing\n", | |
" parser.add_argument(\"-shard_size\", default=2000, type=int)\n", | |
" parser.add_argument('-min_src_nsents', default=3, type=int)\n", | |
" parser.add_argument('-max_src_nsents', default=100, type=int)\n", | |
" parser.add_argument('-min_src_ntokens_per_sent', default=5, type=int)\n", | |
" parser.add_argument('-max_src_ntokens_per_sent', default=200, type=int)\n", | |
" parser.add_argument('-min_tgt_ntokens', default=5, type=int)\n", | |
" parser.add_argument('-max_tgt_ntokens', default=500, type=int)\n", | |
" parser.add_argument(\"-lower\", type=str2bool, nargs='?',const=True,default=True)\n", | |
" parser.add_argument(\"-use_bert_basic_tokenizer\", type=str2bool, nargs='?',const=True,default=False)\n", | |
"\n", | |
" \n", | |
" parser.add_argument(\"-param_init\", default=0, type=float)\n", | |
" parser.add_argument(\"-param_init_glorot\", type=str2bool, nargs='?',const=True,default=True)\n", | |
" parser.add_argument(\"-optim\", default='adam', type=str)\n", | |
" parser.add_argument(\"-lr\", default=1, type=float)\n", | |
" parser.add_argument(\"-beta1\", default= 0.9, type=float)\n", | |
" parser.add_argument(\"-beta2\", default=0.999, type=float)\n", | |
" parser.add_argument(\"-warmup_steps\", default=8000, type=int)\n", | |
" parser.add_argument(\"-warmup_steps_bert\", default=8000, type=int)\n", | |
" parser.add_argument(\"-warmup_steps_dec\", default=8000, type=int)\n", | |
" parser.add_argument(\"-max_grad_norm\", default=0, type=float)\n", | |
"\n", | |
" parser.add_argument(\"-save_checkpoint_steps\", default=5, type=int)\n", | |
" parser.add_argument(\"-accum_count\", default=1, type=int)\n", | |
" parser.add_argument(\"-report_every\", default=1, type=int)\n", | |
" parser.add_argument(\"-train_steps\", default=1000, type=int)\n", | |
" parser.add_argument(\"-recall_eval\", type=str2bool, nargs='?',const=True,default=False)\n", | |
"\n", | |
"\n", | |
" parser.add_argument('-visible_gpus', default='-1', type=str)\n", | |
" parser.add_argument('-gpu_ranks', default='0', type=str)\n", | |
" parser.add_argument('-log_file', default='../../logs/cnndm.log')\n", | |
" parser.add_argument('-seed', default=666, type=int)\n", | |
"\n", | |
" parser.add_argument(\"-test_all\", type=str2bool, nargs='?',const=True,default=False)\n", | |
" parser.add_argument(\"-test_from\", default='')\n", | |
" parser.add_argument(\"-test_start_from\", default=-1, type=int)\n", | |
"\n", | |
" parser.add_argument(\"-train_from\", default='')\n", | |
" parser.add_argument(\"-report_rouge\", type=str2bool, nargs='?',const=True,default=True)\n", | |
" parser.add_argument(\"-block_trigram\", type=str2bool, nargs='?', const=True, default=True)\n", | |
"\n", | |
" args = parser.parse_args()\n", | |
" args.gpu_ranks = [int(i) for i in range(len(args.visible_gpus.split(',')))]\n", | |
" args.world_size = len(args.gpu_ranks)\n", | |
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = args.visible_gpus\n", | |
"\n", | |
" init_logger(args.log_file)\n", | |
" device = \"cpu\" if args.visible_gpus == '-1' else \"cuda\"\n", | |
" device_id = 0 if device == \"cuda\" else -1\n", | |
"\n", | |
" return args, device_id\n", | |
"\n", | |
"if __name__ == '__main__':\n", | |
" args, device_id = init_args()\n", | |
" print(args.task, args.mode) \n", | |
"\n", | |
" cp = args.test_from\n", | |
" try:\n", | |
" \tstep = int(cp.split('.')[-2].split('_')[-1])\n", | |
" except:\n", | |
" \tstep = 0\n", | |
"\n", | |
" predictor = load_models_abs(args, device_id, cp, step)\n", | |
"\n", | |
" all_files = glob.glob(os.path.join('/content/PreSumm/bert_data/cnndm', '*'))\n", | |
" print('Files In Input Dir: ' + str(len(all_files)))\n", | |
" for file in all_files:\n", | |
" with open(file) as f:\n", | |
" source=f.read().rstrip()\n", | |
"\n", | |
" data_builder.str_format_to_bert( source, args, '../bert_data_test/cnndm.test.0.bert.pt') \n", | |
" args.bert_data_path= '../bert_data_test/cnndm'\n", | |
" test_text_abs(args, device_id, cp, step, predictor)" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Overwriting /content/PreSumm/src/summarizer.py\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
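The `str2bool` helper combined with `nargs='?'` and `const=True` is what lets the boolean flags above be passed bare, with an explicit value, or omitted entirely. A small sketch of that pattern in isolation, using the `-lower` flag defined above:

```python
# Sketch of the str2bool argparse pattern used by the summarizer:
# a flag can be given bare (-lower), with a value (-lower no),
# or left out (the default applies).
import argparse

def str2bool(v):
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument("-lower", type=str2bool, nargs="?", const=True, default=True)

print(parser.parse_args([]).lower)               # True (default)
print(parser.parse_args(["-lower"]).lower)       # True (bare flag -> const)
print(parser.parse_args(["-lower", "no"]).lower) # False (explicit value)
```

Without `nargs='?'` and `const`, passing the flag bare would raise an error because argparse would expect a value to follow it.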
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "zEClG-4g2jfL", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"After implementing the summarizer, it can be run via the command-line.\n", | |
"The options supplied to the summarizer will use the *CNN_DailyMail_Abstractive* model, the only one downloaded so far, to summarize the earlier downloaded CNN article." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "GgWL4y9RWtph", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"outputId": "6dacbf3d-3b42-4a3a-95ea-f939c173609a" | |
}, | |
"source": [ | |
"%cd /content/PreSumm/src\n", | |
"!python summarizer.py -task abs -mode test -test_from /content/PreSumm/models/CNN_DailyMail_Abstractive/model_step_148000.pt -batch_size 32 -test_batch_size 500 -bert_data_path ../bert_data/cnndm -log_file ../logs/val_abs_bert_cnndm -report_rouge False -sep_optim true -use_interval true -visible_gpus -1 -max_pos 512 -max_src_nsents 100 -max_length 200 -alpha 0.95 -min_length 50 -result_path ../results/abs_bert_cnndm_sample" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/content/PreSumm/src\n", | |
"abs test\n", | |
"[2020-07-20 14:03:31,479 INFO] Loading checkpoint from /content/PreSumm/models/CNN_DailyMail_Abstractive/model_step_148000.pt\n", | |
"Namespace(accum_count=1, alpha=0.95, batch_size=32, beam_size=5, bert_data_path='../bert_data/cnndm', beta1=0.9, beta2=0.999, block_trigram=True, dec_dropout=0.2, dec_ff_size=2048, dec_heads=8, dec_hidden_size=768, dec_layers=6, enc_dropout=0.2, enc_ff_size=512, enc_hidden_size=512, enc_layers=6, encoder='bert', ext_dropout=0.2, ext_ff_size=2048, ext_heads=8, ext_hidden_size=768, ext_layers=2, finetune_bert=True, generator_shard_size=32, gpu_ranks=[0], label_smoothing=0.1, large=False, load_from_extractive='', log_file='../logs/val_abs_bert_cnndm', lower=True, lr=1, lr_bert=0.002, lr_dec=0.002, max_grad_norm=0, max_length=200, max_pos=512, max_src_nsents=100, max_src_ntokens_per_sent=200, max_tgt_len=140, max_tgt_ntokens=500, min_length=50, min_src_nsents=3, min_src_ntokens_per_sent=5, min_tgt_ntokens=5, mode='test', model_path='../../models/', optim='adam', param_init=0, param_init_glorot=True, recall_eval=False, report_every=1, report_rouge=False, result_path='../results/abs_bert_cnndm_sample', save_checkpoint_steps=5, seed=666, sep_optim=True, shard_size=2000, share_emb=False, task='abs', temp_dir='../../temp', test_all=False, test_batch_size=500, test_from='/content/PreSumm/models/CNN_DailyMail_Abstractive/model_step_148000.pt', test_start_from=-1, train_from='', train_steps=1000, use_bert_basic_tokenizer=False, use_bert_emb=False, use_interval=True, visible_gpus='-1', warmup_steps=8000, warmup_steps_bert=8000, warmup_steps_dec=8000, world_size=1)\n", | |
"[2020-07-20 14:03:33,750 INFO] https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json not found in cache, downloading to /tmp/tmpgbr15xia\n", | |
"100% 433/433 [00:00<00:00, 275296.90B/s]\n", | |
"[2020-07-20 14:03:34,528 INFO] copying /tmp/tmpgbr15xia to cache at ../../temp/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517\n", | |
"[2020-07-20 14:03:34,528 INFO] creating metadata file for ../../temp/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517\n", | |
"[2020-07-20 14:03:34,529 INFO] removing temp file /tmp/tmpgbr15xia\n", | |
"[2020-07-20 14:03:34,529 INFO] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at ../../temp/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517\n", | |
"[2020-07-20 14:03:34,529 INFO] Model config {\n", | |
" \"architectures\": [\n", | |
" \"BertForMaskedLM\"\n", | |
" ],\n", | |
" \"attention_probs_dropout_prob\": 0.1,\n", | |
" \"finetuning_task\": null,\n", | |
" \"hidden_act\": \"gelu\",\n", | |
" \"hidden_dropout_prob\": 0.1,\n", | |
" \"hidden_size\": 768,\n", | |
" \"initializer_range\": 0.02,\n", | |
" \"intermediate_size\": 3072,\n", | |
" \"layer_norm_eps\": 1e-12,\n", | |
" \"max_position_embeddings\": 512,\n", | |
" \"model_type\": \"bert\",\n", | |
" \"num_attention_heads\": 12,\n", | |
" \"num_hidden_layers\": 12,\n", | |
" \"num_labels\": 2,\n", | |
" \"output_attentions\": false,\n", | |
" \"output_hidden_states\": false,\n", | |
" \"pad_token_id\": 0,\n", | |
" \"torchscript\": false,\n", | |
" \"type_vocab_size\": 2,\n", | |
" \"vocab_size\": 30522\n", | |
"}\n", | |
"\n", | |
"[2020-07-20 14:03:35,358 INFO] https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin not found in cache, downloading to /tmp/tmprknxpk2g\n", | |
"100% 440473133/440473133 [00:30<00:00, 14392079.23B/s]\n", | |
"[2020-07-20 14:04:06,720 INFO] copying /tmp/tmprknxpk2g to cache at ../../temp/aa1ef1aede4482d0dbcd4d52baad8ae300e60902e88fcb0bebdec09afd232066.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n", | |
"[2020-07-20 14:04:07,889 INFO] creating metadata file for ../../temp/aa1ef1aede4482d0dbcd4d52baad8ae300e60902e88fcb0bebdec09afd232066.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n", | |
"[2020-07-20 14:04:07,889 INFO] removing temp file /tmp/tmprknxpk2g\n", | |
"[2020-07-20 14:04:07,935 INFO] loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin from cache at ../../temp/aa1ef1aede4482d0dbcd4d52baad8ae300e60902e88fcb0bebdec09afd232066.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n", | |
"[2020-07-20 14:04:14,581 INFO] https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt not found in cache, downloading to /tmp/tmpjz98a_pi\n", | |
"100% 231508/231508 [00:00<00:00, 318857.64B/s]\n", | |
"[2020-07-20 14:04:16,073 INFO] copying /tmp/tmpjz98a_pi to cache at ../../temp/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n", | |
"[2020-07-20 14:04:16,074 INFO] creating metadata file for ../../temp/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n", | |
"[2020-07-20 14:04:16,075 INFO] removing temp file /tmp/tmpjz98a_pi\n", | |
"[2020-07-20 14:04:16,075 INFO] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ../../temp/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n", | |
"Files In Input Dir: 1\n", | |
"[2020-07-20 14:04:16,998 INFO] https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt not found in cache, downloading to /tmp/tmp69ne8o8x\n", | |
"100% 231508/231508 [00:00<00:00, 320421.61B/s]\n", | |
"[2020-07-20 14:04:18,483 INFO] copying /tmp/tmp69ne8o8x to cache at /root/.cache/torch/pytorch_transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n", | |
"[2020-07-20 14:04:18,484 INFO] creating metadata file for /root/.cache/torch/pytorch_transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n", | |
"[2020-07-20 14:04:18,484 INFO] removing temp file /tmp/tmp69ne8o8x\n", | |
"[2020-07-20 14:04:18,484 INFO] loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /root/.cache/torch/pytorch_transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n", | |
"[2020-07-20 14:04:18,534 INFO] Processing A double for Aubameyang as Arsenal stuns Manchester City in FA Cup semifinal\n", | |
"PAUL GITTINGS JULY 18, 2020\n", | |
"\n", | |
"(CNN)Pierre-Emerick Aubameyang scored in each half as Arsenal stunned holder Manchester City 2-0 at Wembley on Saturday to reach a record 21st FA Cup final.\n", | |
"\n", | |
"The Gabonese striker struck in the 19th and 71st minutes as Arsenal coach Mikel Arteta gained bragging rights over Pepe Guardiola, his old mentor at City.\n", | |
"\n", | |
"Manchester City, aiming for a domestic cup double having won the English League Cup before the lockdown, could not make the most of the lion's share of possession as Arsenal defended stoutly and looked ever dangerous in counterattack.\n", | |
"\n", | |
"The opening goal came in this fashion as the Gunners launched a sweet passing move from inside their own half and Nicolas Pepe found Aubameyang, who finished neatly at the far post, all the more creditable having missed another chance moments before.\n", | |
"\n", | |
"The same pattern continued with City unable to capitalize on its pressure before Arsenal's Shkodran Mustafi forced a fine save from City keeper Ederson just before the break.\n", | |
"\n", | |
"Breakaway second goal\n", | |
"City, who had its European football ban lifted by a ruling earlier this week, had a second-half penalty appeal turned down by referee John Moss and the video assistant referee for a Mustafi challenge on Raheem Sterling, while Arsenal's rock-solid Emiliano Martinez saved well from City's Riyad Mahrez.\n", | |
"\n", | |
"The near misses came back to haunt the cupholder as Arsenal broke forward and Aubameyang raced onto Kieran Tierney's pass to shoot home his second past Ederson.\n", | |
"\n", | |
"David Luiz, who made a disastrous appearance for Arsenal in the first game of the restart against City, then epitomized his side's improvement under Arteta by brilliantly blocking a further Sterling effort, while Aymeric Laporte's thunderbolt effort went just wide.\n", | |
"\n", | |
"At the end it was a smiling Arteta, formerly an assistant manager to Guardiola at City, who was celebrating in a cavernous and empty Wembley, normally filled to capacity for such a big match.\n", | |
"\n", | |
"Luiz paid tribute to Arteta's influence since taking over at the North London club late last year, with a 2-1 victory over new Premier League champion Liverpool earlier this week further sign of progress.\n", | |
"\n", | |
"\"We have an amazing coach but we can't go from 0 to 100. We are improving. The spirit was great and I'm happy for the team because they deserve it,\" Luiz told BT Sport.\n", | |
"\n", | |
"\"We have a final to play this season and we will try and win a title for this club because this club deserves to win a trophy,\" the Brazilian added.\n", | |
"\n", | |
"The cup final is set for August 1 and Arsenal will face the winner of the second semifinal on Sunday between Chelsea and Manchester United.\n", | |
"\n", | |
"City will now focus on its Champions League dream with the second leg of its last 16 tie against new Spanish champion Real Madrid next month.\n", | |
"[2020-07-20 14:04:18,550 INFO] Saving to ../bert_data_test/cnndm.test.0.bert.pt\n", | |
"[2020-07-20 14:04:18,554 INFO] Loading test dataset from ../bert_data_test/cnndm.test.0.bert.pt, number of examples: 1\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "15DoA-s92eEe", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Last but not least, the summaries are printed." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QAZS_NzDcPdG", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 54 | |
}, | |
"outputId": "6d44a491-884e-489e-a309-b20b6216dbb6" | |
}, | |
"source": [ | |
"!cat ../results/*candidate" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"pierre-emerick aubameyang scored in each half as arsenal stuns manchester city 2-0 at wembley<q>the gabonese striker struck in the 19th and 71st minutes as arsenal coach mikel arteta gained bragging rights over pepe guardiola\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
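PreSumm joins the sentences of a generated summary with the literal `<q>` token, as visible in the output above. A small, hypothetical post-processing sketch that turns a candidate file into readable prose (the `prettify` helper is not part of PreSumm):

```python
# Hypothetical post-processing: PreSumm emits lowercase summaries with
# sentences joined by the literal token "<q>". This splits on that
# marker, capitalizes each sentence, and rejoins with periods.
import glob

def prettify(summary: str) -> str:
    sentences = [s.strip().capitalize() for s in summary.split("<q>")]
    return ". ".join(sentences) + "."

# The glob pattern matches the `cat` command in the cell above.
for path in glob.glob("../results/*candidate"):
    with open(path) as f:
        for line in f:
            print(prettify(line.rstrip("\n")))
```

Note that `str.capitalize` lowercases everything after the first character, which is harmless here because the model output is already lowercase; proper nouns would need a smarter truecasing step.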
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2TnaDvi4tmfS", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 8, | |
"outputs": [] | |
} | |
] | |
} |