Created December 27, 2024 11:37
modernbert-flexattention.ipynb
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "authorship_tag": "ABX9TyMM4rgNfrxNuLfSx7HSp2fI",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/staghado/c3688a51aadec9e0b63316d8a7227064/modernbert-flexattention.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6Tmj1tVS3rs0",
        "outputId": "dad42f02-b9a8-4f4e-8e24-43d63e48232d"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting git+https://github.com/staghado/transformers.git@flexattn-modernbert\n",
            "  Cloning https://github.com/staghado/transformers.git (to revision flexattn-modernbert) to /tmp/pip-req-build-vc3vt7nd\n",
            "  Running command git clone --filter=blob:none --quiet https://github.com/staghado/transformers.git /tmp/pip-req-build-vc3vt7nd\n",
            "  Running command git checkout -b flexattn-modernbert --track origin/flexattn-modernbert\n",
            "  Switched to a new branch 'flexattn-modernbert'\n",
            "  Branch 'flexattn-modernbert' set up to track remote branch 'flexattn-modernbert' from 'origin'.\n",
            "  Resolved https://github.com/staghado/transformers.git to commit 47758db01f905c9c0ccbe2e4a2082cbc9e12fcc3\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (3.16.1)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.24.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (0.27.0)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (1.26.4)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (24.2)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (6.0.2)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (2024.11.6)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (2.32.3)\n",
            "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (0.21.0)\n",
            "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (0.4.5)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (4.67.1)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.48.0.dev0) (2024.10.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.48.0.dev0) (4.12.2)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.48.0.dev0) (3.4.0)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.48.0.dev0) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.48.0.dev0) (2.2.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.48.0.dev0) (2024.12.14)\n"
          ]
        }
      ],
      "source": [
        "!pip install git+https://github.com/staghado/transformers.git@flexattn-modernbert"
      ]
    },
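    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "*Added sketch (not part of the original run):* a quick sanity check that the `transformers` build installed from the `flexattn-modernbert` branch is the one Python actually imports. The exact dev version string may differ from run to run."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Added sanity check (sketch): confirm the branch build is active.\n",
        "# The pip log above reports transformers==4.48.0.dev0; the version printed here should match.\n",
        "import transformers\n",
        "print(transformers.__version__)"
      ]
    },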
    {
      "cell_type": "code",
      "source": [
        "!pip uninstall torch torchvision torchaudio -y"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DHV6ghBS5tVS",
        "outputId": "1304669f-a86d-45e3-da85-4af8dec9cff3"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Found existing installation: torch 2.6.0.dev20241112+cu121\n",
            "Uninstalling torch-2.6.0.dev20241112+cu121:\n",
            "  Successfully uninstalled torch-2.6.0.dev20241112+cu121\n",
            "Found existing installation: torchvision 0.20.0.dev20241112+cu121\n",
            "Uninstalling torchvision-0.20.0.dev20241112+cu121:\n",
            "  Successfully uninstalled torchvision-0.20.0.dev20241112+cu121\n",
            "Found existing installation: torchaudio 2.5.0.dev20241112+cu121\n",
            "Uninstalling torchaudio-2.5.0.dev20241112+cu121:\n",
            "  Successfully uninstalled torchaudio-2.5.0.dev20241112+cu121\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Afqs0l7R5pQz",
        "outputId": "62e22b30-bcb5-41b6-def6-616f8a46d278"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Looking in indexes: https://download.pytorch.org/whl/nightly/cu121\n",
            "Collecting torch\n",
            "  Using cached https://download.pytorch.org/whl/nightly/cu121/torch-2.6.0.dev20241112%2Bcu121-cp310-cp310-linux_x86_64.whl (767.9 MB)\n",
            "Collecting torchvision\n",
            "  Using cached https://download.pytorch.org/whl/nightly/cu121/torchvision-0.20.0.dev20241112%2Bcu121-cp310-cp310-linux_x86_64.whl (7.4 MB)\n",
            "Collecting torchaudio\n",
            "  Using cached https://download.pytorch.org/whl/nightly/cu121/torchaudio-2.5.0.dev20241112%2Bcu121-cp310-cp310-linux_x86_64.whl (3.4 MB)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n",
            "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.10.0)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch) (9.1.0.70)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch) (11.0.2.54)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n",
            "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.10/dist-packages (from torch) (0.6.2)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/dist-packages (from torch) (2.21.5)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
            "Requirement already satisfied: pytorch-triton==3.1.0+cf34004b8a in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.0+cf34004b8a)\n",
            "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.6.85)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n",
            "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (11.0.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.2)\n",
            "Installing collected packages: torch, torchvision, torchaudio\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "fastai 2.7.18 requires torch<2.6,>=1.10, but you have torch 2.6.0.dev20241112+cu121 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed torch-2.6.0.dev20241112+cu121 torchaudio-2.5.0.dev20241112+cu121 torchvision-0.20.0.dev20241112+cu121\n"
          ]
        }
      ]
    },
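    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "*Added check (sketch, not part of the original run):* FlexAttention ships with recent PyTorch builds under `torch.nn.attention.flex_attention`, so importing it is a cheap way to verify that the nightly wheel installed above actually provides it. After swapping PyTorch wheels, the Colab runtime may need to be restarted before the new `torch` is picked up."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Added check (sketch): verify the nightly torch build exposes FlexAttention.\n",
        "# (A runtime restart may be needed after reinstalling torch in Colab.)\n",
        "import torch\n",
        "from torch.nn.attention.flex_attention import flex_attention, create_block_mask\n",
        "\n",
        "print(torch.__version__)\n",
        "print(flex_attention)  # the kernel entry point that attn_implementation=\"flex_attention\" builds on"
      ]
    },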
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
        "\n",
        "model_id = \"answerdotai/ModernBERT-base\"\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
        "\n",
        "texts = [\n",
        "    \"The capital of France is [MASK].\" *100,\n",
        "    # \"The largest city in Canada is [MASK].\"*200,\n",
        "    # \"The currency of Japan is [MASK].\"*300,\n",
        "    # \"The highest mountain in the world is [MASK].\"*500\n",
        "]\n",
        "\n",
        "implementations = [\"flex_attention\", \"sdpa\", \"eager\"]\n",
        "num_repeats = 3\n",
        "\n",
        "def time_model(attn_implementation, text):\n",
        "    model = AutoModelForMaskedLM.from_pretrained(model_id,\n",
        "        attn_implementation=attn_implementation\n",
        "    ).to(\"cuda\")\n",
        "    inputs = tokenizer(text, return_tensors=\"pt\")\n",
        "    inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
        "    print(f\"Sequence length : {inputs['input_ids'].shape}\")\n",
        "    total_time = 0\n",
        "    for _ in range(num_repeats):\n",
        "        start_time = time.time()\n",
        "        outputs = model(**inputs)\n",
        "        end_time = time.time()\n",
        "        total_time += (end_time - start_time)\n",
        "\n",
        "    return total_time / num_repeats\n",
        "\n",
        "for attn_implementation in implementations:\n",
        "    print(f\"Using attn_implementation={attn_implementation}\")\n",
        "    for text in texts:\n",
        "        avg_time = time_model(attn_implementation, text)\n",
        "        print(f\"  Time taken: {avg_time:.4f} seconds\")\n",
        "    torch.cuda.empty_cache()\n",
        "    import gc; gc.collect()\n",
        "    print()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Rc40hQu3tP-Y",
        "outputId": "17b62137-8872-4eb6-8741-9deadf509758"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using attn_implementation=flex_attention\n",
            "Sequence length : torch.Size([1, 702])\n",
            "  Time taken: 0.7820 seconds\n",
            "\n",
            "Using attn_implementation=sdpa\n",
            "Sequence length : torch.Size([1, 702])\n",
            "  Time taken: 0.0748 seconds\n",
            "\n",
            "Using attn_implementation=eager\n",
            "Sequence length : torch.Size([1, 702])\n",
            "  Time taken: 0.0679 seconds\n",
            "\n"
          ]
        }
      ]
    },
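    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "*Added note (sketch, not from the original run):* the loop above measures wall-clock time without `torch.cuda.synchronize()`, without a warmup pass, and with autograd enabled, so the numbers should be read with caution. CUDA kernels launch asynchronously, and the first `flex_attention` call can pay one-time setup costs, which may be why it looks roughly 10x slower than `sdpa`/`eager` here. The cell below is a hedged sketch of a tighter measurement; `time_model_sync`, `warmup_steps`, and the use of `torch.inference_mode()` are additions, not part of the original notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Added sketch: a more careful GPU timing loop.\n",
        "# - warmup iterations absorb one-time costs (e.g. FlexAttention setup on the first call)\n",
        "# - torch.cuda.synchronize() makes the wall-clock time include the queued GPU work\n",
        "# - torch.inference_mode() avoids building an autograd graph\n",
        "import time\n",
        "import torch\n",
        "\n",
        "def time_model_sync(attn_implementation, text, warmup_steps=2, n_repeats=5):\n",
        "    model = AutoModelForMaskedLM.from_pretrained(\n",
        "        model_id, attn_implementation=attn_implementation\n",
        "    ).to(\"cuda\")\n",
        "    model.eval()\n",
        "    inputs = tokenizer(text, return_tensors=\"pt\").to(\"cuda\")\n",
        "    with torch.inference_mode():\n",
        "        for _ in range(warmup_steps):  # not timed\n",
        "            model(**inputs)\n",
        "        torch.cuda.synchronize()\n",
        "        start = time.time()\n",
        "        for _ in range(n_repeats):\n",
        "            model(**inputs)\n",
        "        torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock\n",
        "    return (time.time() - start) / n_repeats\n",
        "\n",
        "# for impl in [\"flex_attention\", \"sdpa\", \"eager\"]:\n",
        "#     print(impl, f\"{time_model_sync(impl, texts[0]):.4f} s\")"
      ]
    },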
    {
      "cell_type": "code",
      "source": [
        "torch.cuda.empty_cache()\n",
        "import gc; gc.collect()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "05nAn7rixsYZ",
        "outputId": "9043fd09-feff-45c9-bb97-25e5eda891bb"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0"
            ]
          },
          "metadata": {},
          "execution_count": 10
        }
      ]
    },
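    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "*Added sketch:* `torch.cuda.empty_cache()` only returns unused cached blocks to the driver; tensors that are still referenced stay allocated. The allocator stats below make it easy to confirm how much memory is actually held after cleanup."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Added sketch: inspect the CUDA caching allocator after cleanup.\n",
        "import torch\n",
        "print(f\"allocated: {torch.cuda.memory_allocated() / 1e6:.1f} MB\")\n",
        "print(f\"reserved : {torch.cuda.memory_reserved() / 1e6:.1f} MB\")"
      ]
    },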
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "fw0pAr0OzffY"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}