@Pierrci
Last active October 19, 2020 16:03
Transformers to SavedModel.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Transformers to SavedModel.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Pierrci/16ff6601139f07390e08900ea5bdf584/transformers-to-savedmodel.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Lh0AAq89bdOb",
"outputId": "59b0ed2b-afb5-40e5-9420-7fa4929b6d7f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 573
}
},
"source": [
"!pip install git+https://github.com/huggingface/transformers"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting git+https://github.com/huggingface/transformers\n",
" Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-pxx3a22o\n",
" Running command git clone -q https://github.com/huggingface/transformers /tmp/pip-req-build-pxx3a22o\n",
" Installing build dependencies ... done\n",
" Getting requirements to build wheel ... done\n",
" Preparing wheel metadata ... done\n",
"Requirement already satisfied (use --upgrade to upgrade): transformers==3.3.1 from git+https://github.com/huggingface/transformers in /usr/local/lib/python3.6/dist-packages\n",
"Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from transformers==3.3.1) (0.0.43)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers==3.3.1) (20.4)\n",
"Requirement already satisfied: tokenizers==0.9.2 in /usr/local/lib/python3.6/dist-packages (from transformers==3.3.1) (0.9.2)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers==3.3.1) (2.23.0)\n",
"Requirement already satisfied: sentencepiece!=0.1.92 in /usr/local/lib/python3.6/dist-packages (from transformers==3.3.1) (0.1.91)\n",
"Requirement already satisfied: protobuf in /usr/local/lib/python3.6/dist-packages (from transformers==3.3.1) (3.12.4)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers==3.3.1) (3.0.12)\n",
"Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from transformers==3.3.1) (0.7)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers==3.3.1) (2019.12.20)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers==3.3.1) (4.41.1)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers==3.3.1) (1.18.5)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==3.3.1) (1.15.0)\n",
"Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==3.3.1) (0.16.0)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==3.3.1) (7.1.2)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers==3.3.1) (2.4.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==3.3.1) (2020.6.20)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==3.3.1) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==3.3.1) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==3.3.1) (1.24.3)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf->transformers==3.3.1) (50.3.0)\n",
"Building wheels for collected packages: transformers\n",
" Building wheel for transformers (PEP 517) ... done\n",
" Created wheel for transformers: filename=transformers-3.3.1-cp36-none-any.whl size=1255164 sha256=a4ce54a60e66e5485f0f68873f651a1eeb407d3d14232c8441738643647b6080\n",
" Stored in directory: /tmp/pip-ephem-wheel-cache-vyeilqyn/wheels/70/d3/52/b3fa4f8b8ef04167ac62e5bb2accb62ae764db2a378247490e\n",
"Successfully built transformers\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "5ddCYWg9D3M_",
"outputId": "8071e03c-6a6f-4b07-82e9-fa1b4c290996",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 86
}
},
"source": [
"import tensorflow as tf\n",
"from transformers import TFAutoModelForQuestionAnswering\n",
"\n",
"# from_pt=True loads the checkpoint's PyTorch weights and converts them to TensorFlow.\n",
"model = TFAutoModelForQuestionAnswering.from_pretrained('a-ware/roberta-large-squadv2', from_pt=True)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"All PyTorch model weights were used when initializing TFRobertaForQuestionAnswering.\n",
"\n",
"All the weights of TFRobertaForQuestionAnswering were initialized from the PyTorch model.\n",
"If your task is similar to the task the model of the ckeckpoint was trained on, you can already use TFRobertaForQuestionAnswering for predictions without further training.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Y-3Io24fEI26",
"outputId": "bbf7d4a1-63fa-4eda-fdc5-d372d0d899f8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 159
}
},
"source": [
"# Trace model.call with int32 inputs of fixed length 384 and a dynamic batch size.\n",
"input_specs = [tf.TensorSpec([None, 384], tf.int32, name=\"input_ids\"), tf.TensorSpec([None, 384], tf.int32, name=\"attention_mask\")]\n",
"concrete_function = tf.function(model.call).get_concrete_function(input_specs)\n",
"tf.saved_model.save(model, 'savedmodel', signatures=concrete_function)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
"INFO:tensorflow:Assets written to: savedmodel/assets\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LrprRSPB8uFG",
"outputId": "45082b0c-3b0e-41f9-8c37-c007a852646e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 347
}
},
"source": [
"!saved_model_cli show --dir savedmodel --tag_set serve --signature_def serving_default"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"The given SavedModel SignatureDef contains the following input(s):\n",
" inputs['attention_mask'] tensor_info:\n",
" dtype: DT_INT32\n",
" shape: (-1, 384)\n",
" name: serving_default_attention_mask:0\n",
" inputs['input_ids'] tensor_info:\n",
" dtype: DT_INT32\n",
" shape: (-1, 384)\n",
" name: serving_default_input_ids:0\n",
"The given SavedModel SignatureDef contains the following output(s):\n",
" outputs['output_0'] tensor_info:\n",
" dtype: DT_FLOAT\n",
" shape: (-1, 384)\n",
" name: StatefulPartitionedCall:0\n",
" outputs['output_1'] tensor_info:\n",
" dtype: DT_FLOAT\n",
" shape: (-1, 384)\n",
" name: StatefulPartitionedCall:1\n",
"Method name is: tensorflow/serving/predict\n"
],
"name": "stdout"
}
]
}
]
}
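
The `serving_default` signature shown in the last cell accepts two int32 tensors of shape `(-1, 384)`, named `input_ids` and `attention_mask`, so every request has to be padded or truncated to exactly 384 tokens. Below is a minimal sketch of how a client might prepare such a request body in TensorFlow Serving's REST "instances" (row) format. The helper name `build_predict_request`, the example token ids, and the padding id `1` (RoBERTa's `<pad>` token) are assumptions for illustration and are not part of the notebook; in practice the ids would come from the model's tokenizer.

```python
import json

MAX_LEN = 384  # fixed sequence length taken from the exported signature
PAD_ID = 1     # assumption: RoBERTa's <pad> token id


def build_predict_request(batch_token_ids, max_len=MAX_LEN, pad_id=PAD_ID):
    """Pad or truncate each token-id list to max_len and return a JSON
    request body in TensorFlow Serving's row ("instances") format."""
    instances = []
    for ids in batch_token_ids:
        ids = list(ids)[:max_len]  # naive truncation of overly long inputs
        n_real = len(ids)
        n_pad = max_len - n_real
        instances.append({
            "input_ids": ids + [pad_id] * n_pad,
            # 1 marks real tokens, 0 marks padding
            "attention_mask": [1] * n_real + [0] * n_pad,
        })
    return json.dumps({"signature_name": "serving_default",
                       "instances": instances})


# Placeholder token ids; real ids would come from the tokenizer.
body = build_predict_request([[0, 31414, 232, 2], [0, 2]])
payload = json.loads(body)
print(len(payload["instances"]), len(payload["instances"][0]["input_ids"]))
```

The resulting string could then be POSTed to a TensorFlow Serving `:predict` endpoint for whatever model name the server was started with. Note that the naive truncation above would cut off the closing special token of an over-long input; letting the tokenizer handle truncation and padding avoids that.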