Skip to content

Instantly share code, notes, and snippets.

@staghado
Created December 27, 2024 11:37
Show Gist options

  • Select an option

  • Save staghado/c3688a51aadec9e0b63316d8a7227064 to your computer and use it in GitHub Desktop.
modernbert-flexattention.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyMM4rgNfrxNuLfSx7HSp2fI",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/staghado/c3688a51aadec9e0b63316d8a7227064/modernbert-flexattention.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6Tmj1tVS3rs0",
"outputId": "dad42f02-b9a8-4f4e-8e24-43d63e48232d"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting git+https://github.com/staghado/transformers.git@flexattn-modernbert\n",
" Cloning https://github.com/staghado/transformers.git (to revision flexattn-modernbert) to /tmp/pip-req-build-vc3vt7nd\n",
" Running command git clone --filter=blob:none --quiet https://github.com/staghado/transformers.git /tmp/pip-req-build-vc3vt7nd\n",
" Running command git checkout -b flexattn-modernbert --track origin/flexattn-modernbert\n",
" Switched to a new branch 'flexattn-modernbert'\n",
" Branch 'flexattn-modernbert' set up to track remote branch 'flexattn-modernbert' from 'origin'.\n",
" Resolved https://github.com/staghado/transformers.git to commit 47758db01f905c9c0ccbe2e4a2082cbc9e12fcc3\n",
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (3.16.1)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.24.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (0.27.0)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (1.26.4)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (24.2)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (6.0.2)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (2024.11.6)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (2.32.3)\n",
"Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (0.21.0)\n",
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (0.4.5)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.48.0.dev0) (4.67.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.48.0.dev0) (2024.10.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.48.0.dev0) (4.12.2)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.48.0.dev0) (3.4.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.48.0.dev0) (3.10)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.48.0.dev0) (2.2.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.48.0.dev0) (2024.12.14)\n"
]
}
],
"source": [
"!pip install git+https://github.com/staghado/transformers.git@flexattn-modernbert"
]
},
{
"cell_type": "code",
"source": [
"!pip uninstall torch torchvision torchaudio -y"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DHV6ghBS5tVS",
"outputId": "1304669f-a86d-45e3-da85-4af8dec9cff3"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found existing installation: torch 2.6.0.dev20241112+cu121\n",
"Uninstalling torch-2.6.0.dev20241112+cu121:\n",
" Successfully uninstalled torch-2.6.0.dev20241112+cu121\n",
"Found existing installation: torchvision 0.20.0.dev20241112+cu121\n",
"Uninstalling torchvision-0.20.0.dev20241112+cu121:\n",
" Successfully uninstalled torchvision-0.20.0.dev20241112+cu121\n",
"Found existing installation: torchaudio 2.5.0.dev20241112+cu121\n",
"Uninstalling torchaudio-2.5.0.dev20241112+cu121:\n",
" Successfully uninstalled torchaudio-2.5.0.dev20241112+cu121\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Afqs0l7R5pQz",
"outputId": "62e22b30-bcb5-41b6-def6-616f8a46d278"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/nightly/cu121\n",
"Collecting torch\n",
" Using cached https://download.pytorch.org/whl/nightly/cu121/torch-2.6.0.dev20241112%2Bcu121-cp310-cp310-linux_x86_64.whl (767.9 MB)\n",
"Collecting torchvision\n",
" Using cached https://download.pytorch.org/whl/nightly/cu121/torchvision-0.20.0.dev20241112%2Bcu121-cp310-cp310-linux_x86_64.whl (7.4 MB)\n",
"Collecting torchaudio\n",
" Using cached https://download.pytorch.org/whl/nightly/cu121/torchaudio-2.5.0.dev20241112%2Bcu121-cp310-cp310-linux_x86_64.whl (3.4 MB)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n",
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.2)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.10.0)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch) (9.1.0.70)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch) (11.0.2.54)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.10/dist-packages (from torch) (0.6.2)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.10/dist-packages (from torch) (2.21.5)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
"Requirement already satisfied: pytorch-triton==3.1.0+cf34004b8a in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.0+cf34004b8a)\n",
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.6.85)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (11.0.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.2)\n",
"Installing collected packages: torch, torchvision, torchaudio\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"fastai 2.7.18 requires torch<2.6,>=1.10, but you have torch 2.6.0.dev20241112+cu121 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0mSuccessfully installed torch-2.6.0.dev20241112+cu121 torchaudio-2.5.0.dev20241112+cu121 torchvision-0.20.0.dev20241112+cu121\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import gc\n",
"import time\n",
"\n",
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
"\n",
"model_id = \"answerdotai/ModernBERT-base\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"\n",
"# Repeat the sentence to build a long (~700-token) benchmark sequence.\n",
"texts = [\n",
"    \"The capital of France is [MASK].\" * 100,\n",
"]\n",
"\n",
"implementations = [\"flex_attention\", \"sdpa\", \"eager\"]\n",
"num_repeats = 3\n",
"\n",
"def time_model(attn_implementation, text):\n",
"    \"\"\"Return the mean forward-pass latency in seconds for one attention backend.\"\"\"\n",
"    model = AutoModelForMaskedLM.from_pretrained(\n",
"        model_id, attn_implementation=attn_implementation\n",
"    ).to(\"cuda\")\n",
"    model.eval()  # inference mode: disable dropout for stable timings\n",
"    inputs = tokenizer(text, return_tensors=\"pt\")\n",
"    inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
"    print(f\"Sequence length : {inputs['input_ids'].shape}\")\n",
"    total_time = 0.0\n",
"    with torch.no_grad():  # no autograd graph needed for a latency benchmark\n",
"        # Warmup: the first call pays one-time kernel compilation / autotune\n",
"        # cost (large for flex_attention) and must not be timed.\n",
"        model(**inputs)\n",
"        torch.cuda.synchronize()\n",
"        for _ in range(num_repeats):\n",
"            start_time = time.perf_counter()\n",
"            model(**inputs)\n",
"            # CUDA launches are asynchronous; wait for the GPU to finish\n",
"            # before stopping the clock, otherwise only the launch is timed.\n",
"            torch.cuda.synchronize()\n",
"            total_time += time.perf_counter() - start_time\n",
"    del model  # drop the reference so empty_cache() below can reclaim VRAM\n",
"    return total_time / num_repeats\n",
"\n",
"for attn_implementation in implementations:\n",
"    print(f\"Using attn_implementation={attn_implementation}\")\n",
"    for text in texts:\n",
"        avg_time = time_model(attn_implementation, text)\n",
"        print(f\"  Time taken: {avg_time:.4f} seconds\")\n",
"        gc.collect()\n",
"        torch.cuda.empty_cache()\n",
"    print()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Rc40hQu3tP-Y",
"outputId": "17b62137-8872-4eb6-8741-9deadf509758"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Using attn_implementation=flex_attention\n",
"Sequence length : torch.Size([1, 702])\n",
" Time taken: 0.7820 seconds\n",
"\n",
"Using attn_implementation=sdpa\n",
"Sequence length : torch.Size([1, 702])\n",
" Time taken: 0.0748 seconds\n",
"\n",
"Using attn_implementation=eager\n",
"Sequence length : torch.Size([1, 702])\n",
" Time taken: 0.0679 seconds\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"torch.cuda.empty_cache()\n",
"import gc; gc.collect()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "05nAn7rixsYZ",
"outputId": "9043fd09-feff-45c9-bb97-25e5eda891bb"
},
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "fw0pAr0OzffY"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment