Skip to content

Instantly share code, notes, and snippets.

@joey00072
joey00072 / mla.py
Created December 28, 2024 16:25
multi head latent attention (MLA)
# https://x.com/shxf0072/status/1873038335427658011
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from collections import OrderedDict
from ohara.modules.norm import RMSNorm
from transformers import AutoTokenizer
from transformers import LlamaConfig, LlamaForCausalLM
import torch
model_name = "TinyLlama/TinyLlama_v1.1"
config = LlamaConfig.from_pretrained(model_name,attn_implementation="eager")
# injecting customs values in cfg
customs = {"segment_size":128,"delta_update":True,"use_cache":False}
py3-none-manylinux1_x86_64.whl
2024-05-02T10:08:19,717 Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)
2024-05-02T10:08:22,178 Downloading link https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (from https://pypi.org/simple/nvidia-cusparse-cu12/) (requires-python:>=3) to /tmp/pip-unpack-yun69vx1/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl
2024-05-02T10:08:22,194 Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)
2024-05-02T10:08:24,222 Downloading link https://files.pythonhosted.org/packages/4b/2a/0a131f572aa09f741c30ccd45a8e56316e8be8dfc7bc19bf0ab7cfef7b19/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (from https://pypi.org/simple/nvidia-nccl-cu12/) (requires-python:>=3) to /tmp/pip-unpack-yun69vx1/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl
2024-05-02T10:08:24,237 Downloading nv
This file has been truncated, but you can view the full file.
2024-05-02T10:03:33,883 Using pip 23.3.1 from /root/miniconda3/lib/python3.12/site-packages/pip (python 3.12)
2024-05-02T10:03:33,892 Non-user install because site-packages writeable
2024-05-02T10:03:33,984 Created temporary directory: /tmp/pip-build-tracker-nyz0804a
2024-05-02T10:03:33,986 Initialized build tracking at /tmp/pip-build-tracker-nyz0804a
2024-05-02T10:03:33,988 Created build tracker: /tmp/pip-build-tracker-nyz0804a
2024-05-02T10:03:33,990 Entered build tracker: /tmp/pip-build-tracker-nyz0804a
2024-05-02T10:03:33,992 Created temporary directory: /tmp/pip-install-u1_3v84s
2024-05-02T10:03:33,995 Created temporary directory: /tmp/pip-ephem-wheel-cache-4co1m56s
2024-05-02T10:03:34,020 Obtaining file:///workspace/BitBLAS
⚡ main ~/BitBLAS cat xoxoxox.log
2024-05-01T13:58:47,521 Using pip 23.0 from /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pip (python 3.10)
2024-05-01T13:58:47,521 Non-user install because site-packages writeable
2024-05-01T13:58:47,613 Created temporary directory: /tmp/pip-build-tracker-dyggwpk3
2024-05-01T13:58:47,614 Initialized build tracking at /tmp/pip-build-tracker-dyggwpk3
2024-05-01T13:58:47,614 Created build tracker: /tmp/pip-build-tracker-dyggwpk3
2024-05-01T13:58:47,614 Entered build tracker: /tmp/pip-build-tracker-dyggwpk3
2024-05-01T13:58:47,614 Created temporary directory: /tmp/pip-install-hogqz06n
2024-05-01T13:58:47,614 Created temporary directory: /tmp/pip-ephem-wheel-cache-gx8wrv34
2024-05-01T13:58:47,621 Obtaining file:///teamspace/studios/this_studio/BitBLAS