| Created | aliases | References |
| --- | --- | --- |
| 2025-01-10 | | |
""" | |
Modified version of https://github.com/Dao-AILab/flash-attention/blob/87a1277653fc55cd615f5341255e00c69d5c00a1/flash_attn/flash_attn_triton.py | |
Experiments with attention bias by andreas.koepf | |
Main fix was "fixing the fix", e.g. removing lines of the original like: | |
``` | |
# BUG: have to store and immediately load | |
# tl.store(t_ptrs, o_scale) | |
# o_scale = tl.load(t_ptrs) | |
``` |
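The removed store/load pair guarded the online-softmax rescaling of the output accumulator. Below is a minimal plain-PyTorch sketch of that step, not the Triton kernel itself; the variable names (`acc`, `l_i`, `m_i`) mirror the kernel's but are assumptions about its internals.

```
import torch

def online_softmax_update(acc, l_i, m_i, qk_block, v_block):
    # new running max after seeing this block of key scores
    m_ij = torch.maximum(m_i, qk_block.max(dim=-1).values)
    p = torch.exp(qk_block - m_ij[:, None])
    # factor that rescales the previous accumulator onto the new max;
    # with the fix it is used directly instead of a tl.store/tl.load round-trip
    o_scale = torch.exp(m_i - m_ij)
    acc = acc * o_scale[:, None] + p @ v_block
    l_i = l_i * o_scale + p.sum(dim=-1)
    return acc, l_i, m_ij
```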
""" | |
*Experimental* implementation of FlashAttention in Triton. | |
Tested with triton==2.0.0.dev20221202. | |
Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions | |
other than 64: | |
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 | |
We'll update this implementation with the new Triton backend once this is fixed. | |
We use the FlashAttention implementation from Phil Tillet a starting point. | |
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py |
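For reference, a hedged usage sketch of the kernel's Python entry point; it assumes the gist keeps the upstream interface `flash_attn_func(q, k, v, bias=None, causal=False, softmax_scale=None)` with q/k/v of shape (batch, seqlen, nheads, headdim) and headdim = 64, per the limitation noted above.

```
import torch
# from flash_attn_triton import flash_attn_func  # module name is an assumption

device = "cuda" if torch.cuda.is_available() else "cpu"
batch, seqlen, nheads, headdim = 2, 1024, 8, 64  # head dimensions other than 64 are not supported yet
q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
# additive attention bias, broadcastable to (batch, nheads, seqlen, seqlen)
bias = torch.zeros(1, 1, seqlen, seqlen, device=device, dtype=torch.float16)
# out = flash_attn_func(q, k, v, bias, True)  # True enables causal masking
```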
from datasets import load_dataset

# load a locally prepared copy of the OASST1 dataset (config name "ready")
ds = load_dataset("/path/oasst1", name='ready')
train = ds['train']
val = ds['validation']
print(f'{len(train)=}')
print(f'{len(val)=}')

# print the message tree ids of the first few training examples
for i in range(5):
    print(train[i]["message_tree_id"])
import argparse
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", type=str, help="checkpoint path or model name")
    parser.add_argument("--dtype", type=str, default="auto", help="auto, fp16, bf16 or fp32")
    return parser.parse_args()
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information.
import argparse
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Push checkpoints in HF transformers format to the Huggingface Hub.",
    )
    return parser.parse_args()
from typing import Optional

import torch


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, scaling_factor: float = 1.0
) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, device=freqs.device).float() / scaling_factor  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    # standard completion of the rotary-embedding precompute: unit-magnitude complex
    # numbers cos(f) + i*sin(f), as in the Llama reference implementation
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
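Hypothetical usage: dividing the positions t by scaling_factor is the usual linear RoPE-scaling (position interpolation) trick, so scaling_factor=2.0 lets a model trained with 2048 positions cover 4096.

```
freqs_cis = precompute_freqs_cis(dim=128, end=4096, scaling_factor=2.0)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([4096, 64]) torch.complex64
```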
# sent to me by tju01, thx

# install base tools
apt update
apt install protobuf-compiler libssl-dev gcc pkg-config g++ make

# install rust
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
source "$HOME/.cargo/env"
# adapted from: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils import checkpoint

from einops import rearrange, repeat

import triton