Aria scripts
"""Contains generation/sampling code""" | |
# This file contains code from https://github.com/facebookresearch/llama which | |
# is available under the following licence: | |
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# This software may be used and distributed according to the terms of the GNU | |
# General Public License version 3. | |
import torch | |
from typing import List | |
from aria.model import TransformerLM | |
from aria.tokenizer import Tokenizer | |
# TODO: | |
# - Enable sampling sequences longer than max_seq_len by truncating | |
# Some good settings: | |
# temp=0.85, top_p=0.9, cfg_gamma=1.4 | |
@torch.autocast(device_type="cuda", dtype=torch.float16)
def interpolate_sample(
    model: TransformerLM,
    tokenizer: Tokenizer,
    prompts: List[list],
    alternative: List[list],
    max_seq_len: int,
    max_gen_len: int,
    force_end=False,
    temperature: float = 0.85,
    top_p: float = 0.9,
    cfg_gamma: float | None = 1.2,
    alpha: float | None = 0.3,
):
    """Performs greedy (top_p) autoregressive sampling on a batch of prompts,
    interpolating the logits towards those of an alternative sequence.

    Args:
        model (TransformerLM): Model to sample from.
        tokenizer (Tokenizer): Tokenizer corresponding to model.
        prompts (List[list]): A list of prompts to sample as a batch.
        alternative (List[list]): Alternative sequences whose logits the
            sampled logits are interpolated towards.
        max_seq_len (int): Maximum sequence length supported by the model.
        max_gen_len (int): Maximum desired sequence length of the samples.
        force_end (bool, optional): Whether to force a dim token to be
            inserted near the end of generation. Defaults to False.
        temperature (float, optional): Sampling temperature. Defaults to 0.85.
        top_p (float, optional): Parameter for top-p sampling. Defaults to 0.9.
        cfg_gamma (float | None, optional): Guidance strength; None disables
            the interpolation. Defaults to 1.2.
        alpha (float | None, optional): Interpolation coefficient above which
            the alternative stream starts mirroring the generated tokens.
            Defaults to 0.3.

    Returns:
        List[list]: The list of samples, decoded by the tokenizer.
    """
    assert tokenizer.return_tensors is True, "tokenizer must return tensors."
    model.eval()

    pad_id = tokenizer.pad_id
    eos_id = tokenizer.tok_to_id[tokenizer.eos_tok]

    bsz = len(prompts)
    min_prompt_size = min([len(t) for t in prompts])
    max_prompt_size = max([len(t) for t in prompts])
    total_len = min(max_seq_len, max_gen_len + max_prompt_size)

    if cfg_gamma:
        assert (
            min_prompt_size == max_prompt_size
        ), "CFG not supported with varying prompts"
    if force_end:
        assert (
            total_len - max_prompt_size > 130
        ), "prompt too long to use force_end=True"

    print(
        f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_gen_len}"
    )

    tokens = torch.full((bsz, total_len), pad_id).cuda()
    alt_tokens = torch.full((bsz, total_len), pad_id).cuda()
    alt_len = min(total_len, min(len(a) for a in alternative))
    for idx, (unencoded_seq, alt_seq) in enumerate(zip(prompts, alternative)):
        tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq)
        alt_tokens[idx, :alt_len] = tokenizer.encode(alt_seq)[:alt_len]

    dim_tok_inserted = [False for _ in range(bsz)]
    input_text_mask = tokens != pad_id
    start_pos = min_prompt_size

    past_kv = None
    alt_kv = None
    _use_cache = True

    with torch.inference_mode():
        for cur_pos in range(start_pos, total_len):
            # With the kv cache, only the newly generated token needs to be
            # fed on every step after the first.
            token = (
                tokens[:, :start_pos]
                if cur_pos == start_pos
                else tokens[:, cur_pos - 1 : cur_pos]
            )
            logits, past_kv = model.forward(
                token, use_cache=_use_cache, past_kv=past_kv
            )
            logits = logits[:, -1, :]

            # Interpolation coefficient: ramps linearly from 0 to cfg_gamma
            # over the course of generation.
            coeff = (
                (cur_pos - start_pos) / (total_len - start_pos) * cfg_gamma
                if cfg_gamma
                else 0.0
            )
            if cfg_gamma and max_prompt_size < cur_pos:
                alt_tok = (
                    alt_tokens[:, :start_pos]
                    if cur_pos == start_pos
                    else alt_tokens[:, cur_pos - 1 : cur_pos]
                )
                alt_logits, alt_kv = model.forward(
                    alt_tok, use_cache=_use_cache, past_kv=alt_kv
                )
                alt_logits = alt_logits[:, -1, :]
                logits = logits + coeff * (alt_logits - logits)

            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)

            # Only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )

            # Insert dim tokens
            if force_end and cur_pos >= total_len - 130:
                for _idx in range(bsz):
                    if (
                        dim_tok_inserted[_idx] is False
                        and tokenizer.id_to_tok[next_token[_idx].item()][0] != "dur"
                    ):
                        next_token[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok]

            # Update dim_tok_inserted
            for _idx in range(bsz):
                if next_token[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]:
                    dim_tok_inserted[_idx] = True

            tokens[:, cur_pos] = next_token
            if cur_pos >= alt_len or (alpha is not None and coeff > alpha):
                alt_tokens[:, cur_pos] = next_token

    decoded = []
    for idx, seq in enumerate(tokens.tolist()):
        # Cut to max gen len
        seq = seq[: len(prompts[idx]) + max_gen_len]
        # Cut to eos tok if any
        try:
            seq = seq[: seq.index(eos_id)]
        except ValueError:
            pass
        decoded.append(tokenizer.decode(seq))

    return decoded
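

# A minimal sketch (hypothetical helper, not part of the aria API) of the
# interpolation schedule used above: `coeff` ramps linearly from 0 at the end
# of the prompt to `cfg_gamma` at `total_len`, and once it exceeds `alpha`
# the alternative stream starts mirroring the generated tokens.
def _demo_interpolation_schedule(
    start_pos: int = 64,
    total_len: int = 512,
    cfg_gamma: float = 1.2,
    alpha: float = 0.3,
):
    for cur_pos in (start_pos, 128, 256, 384, total_len - 1):
        coeff = (cur_pos - start_pos) / (total_len - start_pos) * cfg_gamma
        print(f"pos={cur_pos}: coeff={coeff:.3f}, mirror_alt={coeff > alpha}")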


def sample_top_p(probs, p):
    """Nucleus (top-p) sampling: keep the smallest set of tokens whose
    cumulative probability exceeds p, renormalise, and sample from it."""
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Mask out tokens once the cumulative mass (excluding the current one)
    # already exceeds p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map back from sorted order to original token ids.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
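

# Minimal worked example (illustrative only) of the nucleus sampling helper
# above on a hand-made distribution.
def _demo_sample_top_p():
    probs = torch.tensor([[0.5, 0.3, 0.1, 0.05, 0.05]])
    # With p=0.9, tokens in the tail beyond the 0.9 cumulative-mass cut-off
    # are zeroed out, the rest is renormalised, and one index is drawn.
    next_token = sample_top_p(probs, p=0.9)
    assert next_token.shape == (1, 1)
    print(next_token)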
"""Includes (PyTorch) transformer model and config classes.""" | |
import torch | |
import torch.utils.checkpoint | |
from torch import nn as nn | |
from torch.nn import functional as F | |
class ModelConfig: | |
def __init__( | |
self, | |
d_model: int, | |
n_heads: int, | |
n_layers: int, | |
ff_mult: int, | |
drop_p: float, | |
max_seq_len: int, | |
grad_checkpoint: bool, | |
): | |
self.d_model = d_model | |
self.n_heads = n_heads | |
self.n_layers = n_layers | |
self.ff_mult = ff_mult | |
self.drop_p = drop_p | |
self.max_seq_len = max_seq_len | |
self.grad_checkpoint = grad_checkpoint | |
def set_vocab_size(self, vocab_size: int): | |
self.vocab_size = vocab_size | |


# Taken from GPT-NeoX see:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        if device is None:  # todo: maybe we don't need this...
            device = "cuda" if torch.cuda.is_available() else None
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to
        # obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat(
        (-x2, x1), dim=x1.ndim - 1
    )  # dim=-1 triggers a bug in earlier torch versions


@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin, past_len: int = 0):
    """Returns tuple (xq, xk). Expects shape (s_len, b_sz, n_head, d_head)."""
    cos = cos[past_len : past_len + q.size(0), None, None]
    sin = sin[past_len : past_len + q.size(0), None, None]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
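

# Shape sketch (illustrative helper, not part of the original module):
# apply_rotary_pos_emb expects q/k laid out as (s_len, b_sz, n_head, d_head)
# and cos/sin as produced by RotaryEmbedding, with `past_len` selecting the
# rotation angles for a single step of cached decoding.
def _demo_rotary_shapes():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    d_head, n_head, b_sz, s_len, past_len = 16, 2, 3, 5, 4
    rot = RotaryEmbedding(d_head).to(device)
    cos, sin = rot(x=torch.empty(1, device=device), seq_len=past_len + s_len)
    q = torch.randn(s_len, b_sz, n_head, d_head, device=device)
    k = torch.randn_like(q)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, past_len=past_len)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape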


class FusedEncoderBlock(nn.Module):
    """Transformer block using F.scaled_dot_product_attention().

    This block has the following changes from a typical transformer encoder:
        - Rotary embeddings are applied to the key/query matrices.
        - Layer norm is applied before attention and feed forward, instead of
          after.
        - Keys arising from padding are masked during attention.
        - GELU activation is used instead of ReLU.

    Args:
        model_config (ModelConfig): Model config settings.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()
        self.drop_p = model_config.drop_p
        self.n_heads = model_config.n_heads
        self.d_head = model_config.d_model // model_config.n_heads
        self.max_seq_len = model_config.max_seq_len

        # Positional embeddings
        self.rotary_emb = RotaryEmbedding(self.d_head)

        # Attention
        self.mixed_qkv = nn.Linear(
            in_features=model_config.d_model,
            out_features=3 * model_config.d_model,
            bias=False,
        )
        self.att_proj_linear = nn.Linear(
            in_features=model_config.d_model,
            out_features=model_config.d_model,
        )
        self.resid_dropout = nn.Dropout(model_config.drop_p)

        # FF Layer
        self.ff_dropout = nn.Dropout(model_config.drop_p)
        self.ff_linear_1 = nn.Linear(
            in_features=model_config.d_model,
            out_features=model_config.d_model * model_config.ff_mult,
        )
        self.ff_linear_2 = nn.Linear(
            in_features=model_config.d_model * model_config.ff_mult,
            out_features=model_config.d_model,
        )
        self.ff_activation = nn.GELU()

        # Pre layer norms
        self.norm1 = nn.LayerNorm(model_config.d_model)
        self.norm2 = nn.LayerNorm(model_config.d_model)

    def forward(self, x: torch.Tensor, use_cache=False, past_kv=None):
        att, kv = self._att_block(
            self.norm1(x), use_cache=use_cache, past_kv=past_kv
        )
        x = x + att
        x = x + self._ff_block(self.norm2(x))

        return x, kv

    def _att_block(self, x: torch.Tensor, use_cache=False, past_kv=None):
        batch_size, seq_len, _ = x.shape
        mixed_qkv = self.mixed_qkv(x)
        xq, xk, xv = mixed_qkv.chunk(3, -1)

        # Reshape for rotary embeddings
        xq = xq.view(batch_size, seq_len, self.n_heads, self.d_head)
        xk = xk.view(batch_size, seq_len, self.n_heads, self.d_head)
        xv = xv.view(batch_size, seq_len, self.n_heads, self.d_head)

        past_len = 0 if past_kv is None else past_kv[0].size(1)
        # apply_rotary_pos_emb expects: (s_len, b_sz, n_head, d_head)
        cos, sin = self.rotary_emb(x=xv, seq_len=seq_len + past_len)
        xq, xk = xq.transpose(0, 1), xk.transpose(0, 1)
        xq, xk = apply_rotary_pos_emb(q=xq, k=xk, cos=cos, sin=sin, past_len=past_len)
        xq, xk = xq.transpose(0, 1), xk.transpose(0, 1)
        # xq, xk: (b_sz, s_len, n_head, d_head)

        # Prepend cached keys/values from previous steps, if any.
        if past_kv is not None:
            assert len(past_kv) == 2
            xk = torch.concat([past_kv[0], xk], axis=1)
            xv = torch.concat([past_kv[1], xv], axis=1)
        kv = (xk, xv)

        # Reshape for attention calculation: (b_sz, n_head, s_len, d_head)
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Required as we are not using a nn.Dropout layer
        if self.training:
            att_dropout = 0.1  # NOTE: hardcoded; arguably this should be self.drop_p
        else:
            att_dropout = 0.0

        # Using beta torch functionality (subject to change)
        # See - https://shorturl.at/jtI17
        if past_kv is None:
            att = F.scaled_dot_product_attention(
                query=xq,
                key=xk,
                value=xv,
                dropout_p=att_dropout,
                is_causal=True,
            )
        else:
            # With the kv cache only a single query position is processed, so
            # it may attend to every cached position (no causal mask needed).
            assert xq.size(2) == 1
            mask = torch.ones(1, xk.size(2), dtype=bool, device=xk.device)
            att = F.scaled_dot_product_attention(
                query=xq,
                key=xk,
                value=xv,
                dropout_p=att_dropout,
                is_causal=False,
                attn_mask=mask,
            )

        # Reshape for out: (b_sz, s_len, n_head, d_head)
        out = att.transpose(1, 2).contiguous()
        out = out.view(batch_size, seq_len, self.n_heads * self.d_head)

        return self.resid_dropout(self.att_proj_linear(out)), (
            kv if use_cache else None
        )

    def _ff_block(self, x: torch.Tensor):
        x = self.ff_linear_2(self.ff_activation(self.ff_linear_1(x)))

        return self.ff_dropout(x)
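

# Minimal sketch (illustrative helper, not part of the original module)
# checking that one step of cached decoding reproduces the final position of
# a full causal forward pass through a FusedEncoderBlock.
def _demo_block_kv_cache():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = ModelConfig(
        d_model=32,
        n_heads=4,
        n_layers=1,
        ff_mult=4,
        drop_p=0.0,
        max_seq_len=64,
        grad_checkpoint=False,
    )
    block = FusedEncoderBlock(cfg).to(device).eval()
    x = torch.randn(2, 8, cfg.d_model, device=device)
    with torch.no_grad():
        full, _ = block(x, use_cache=False, past_kv=None)
        _, kv = block(x[:, :7], use_cache=True, past_kv=None)
        step, _ = block(x[:, 7:8], use_cache=True, past_kv=kv)
    torch.testing.assert_close(step[:, -1], full[:, -1], rtol=1e-3, atol=1e-3)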


class Transformer(nn.Module):
    """Transformer decoder with no language model head.

    Args:
        model_config (ModelConfig): Model config settings.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()
        self.model_config = model_config

        self.tok_embeddings = nn.Embedding(
            num_embeddings=model_config.vocab_size,
            embedding_dim=model_config.d_model,
        )
        self.out_layer_norm = nn.LayerNorm(model_config.d_model)
        self.encode_layers = nn.ModuleList()
        for _ in range(model_config.n_layers):
            self.encode_layers.append(FusedEncoderBlock(model_config))

    def forward(self, src: torch.Tensor, use_cache=False, past_kv=None):
        """Forward pass of Transformer.

        Args:
            src (torch.tensor): Input token ids, of shape (batch_size,
                seq_len).

        Returns:
            torch.tensor: Model outputs with shape (batch_size, seq_len,
                d_model).
        """
        hidden_states = self.tok_embeddings(src)
        assert src.shape[1] <= self.model_config.max_seq_len, "Too long."

        # NOTE: If you want to use gradient checkpointing then you must
        # remove torch.compile from the train script as this is not currently
        # supported.
        # Implements gradient checkpoints on Encoder Layers.
        if self.model_config.grad_checkpoint is True:
            new_past_kv = None
            for layer in self.encode_layers:

                def create_custom_forward(module):
                    def custom_forward(*args):
                        # Discard the kv cache when checkpointing (training).
                        return module(*args)[0]

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    hidden_states,
                    preserve_rng_state=True,
                    use_reentrant=True,
                )
        else:
            new_past_kv = []
            past_kv = (
                [None] * len(self.encode_layers) if past_kv is None else past_kv
            )
            for layer, _kv in zip(self.encode_layers, past_kv):
                hidden_states, kv = layer(
                    hidden_states, use_cache=use_cache, past_kv=_kv
                )
                new_past_kv.append(kv)

        return self.out_layer_norm(hidden_states), (
            new_past_kv if use_cache else None
        )


class TransformerLM(nn.Module):
    """Transformer decoder with head for language modelling.

    Args:
        model_config (ModelConfig): Model config settings.
    """

    def __init__(self, model_config: ModelConfig):
        super().__init__()
        self.max_seq_len = model_config.max_seq_len
        self.model = Transformer(model_config)
        self.lm_head = nn.Linear(
            model_config.d_model, model_config.vocab_size, bias=False
        )

    def forward(self, src: torch.Tensor, use_cache=False, past_kv=None):
        """Forward pass of Transformer decoder with LM head.

        Args:
            src (torch.tensor): Input token ids, of shape (batch_size,
                seq_len).

        Returns:
            torch.tensor: Logits with shape (batch_size, seq_len, vocab_size).
                If use_cache is True, a (logits, past_kv) tuple is returned
                instead.
        """
        hidden, past_kv = self.model(src, use_cache=use_cache, past_kv=past_kv)
        logits = self.lm_head(hidden)

        if use_cache:
            return logits, past_kv
        else:
            return logits
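

# Minimal end-to-end sketch (illustrative only; the hyperparameters are
# arbitrary toy values): build a tiny TransformerLM and check that cached
# incremental decoding reproduces the logits of a full forward pass.
def _demo_transformer_lm():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = ModelConfig(
        d_model=64,
        n_heads=4,
        n_layers=2,
        ff_mult=4,
        drop_p=0.0,
        max_seq_len=128,
        grad_checkpoint=False,
    )
    config.set_vocab_size(100)
    model = TransformerLM(config).to(device).eval()
    src = torch.randint(0, 100, (2, 16), device=device)
    with torch.no_grad():
        full_logits = model(src, use_cache=False)  # (2, 16, 100)
        # Prime the cache on the first 15 tokens, then decode a single step.
        _, past_kv = model(src[:, :15], use_cache=True, past_kv=None)
        step_logits, _ = model(src[:, 15:16], use_cache=True, past_kv=past_kv)
    torch.testing.assert_close(
        step_logits[:, -1], full_logits[:, -1], rtol=1e-3, atol=1e-3
    )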
"""Contains generation/sampling code""" | |
# This file contains code from https://github.com/facebookresearch/llama which | |
# is available under the following licence: | |
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# This software may be used and distributed according to the terms of the GNU | |
# General Public License version 3. | |
import torch | |
from typing import List | |
from aria.model import TransformerLM | |
from aria.tokenizer import Tokenizer | |
# TODO: | |
# - Enable sampling sequences longer than max_seq_len by truncating | |
# Some good settings: | |
# temp=0.85, top_p=0.9, cfg_gamma=1.4 | |
@torch.autocast(device_type="cuda", dtype=torch.float16)
def greedy_sample(
    model: TransformerLM,
    tokenizer: Tokenizer,
    prompts: List[list],
    max_seq_len: int,
    max_gen_len: int,
    force_end=False,
    temperature: float = 0.85,
    top_p: float = 0.9,
    cfg_gamma: float | None = 1.2,
):
    """Performs greedy (top_p) autoregressive sampling on a batch of prompts.

    Args:
        model (TransformerLM): Model to sample from.
        tokenizer (Tokenizer): Tokenizer corresponding to model.
        prompts (List[list]): A list of prompts to sample as a batch.
        max_seq_len (int): Maximum sequence length supported by the model.
        max_gen_len (int): Maximum desired sequence length of the samples.
        force_end (bool, optional): Whether to force a dim token to be
            inserted near the end of generation. Defaults to False.
        temperature (float, optional): Sampling temperature. Defaults to 0.85.
        top_p (float, optional): Parameter for top-p sampling. Defaults to 0.9.
        cfg_gamma (float | None, optional): Classifier-free-guidance strength;
            None disables guidance. Defaults to 1.2.

    Returns:
        List[list]: The list of samples, decoded by the tokenizer.
    """
    assert tokenizer.return_tensors is True, "tokenizer must return tensors."
    model.eval()

    pad_id = tokenizer.pad_id
    eos_id = tokenizer.tok_to_id[tokenizer.eos_tok]

    bsz = len(prompts)
    min_prompt_size = min([len(t) for t in prompts])
    max_prompt_size = max([len(t) for t in prompts])
    total_len = min(max_seq_len, max_gen_len + max_prompt_size)

    if cfg_gamma:
        assert (
            min_prompt_size == max_prompt_size
        ), "CFG not supported with varying prompts"
    if force_end:
        assert (
            total_len - max_prompt_size > 130
        ), "prompt too long to use force_end=True"

    print(
        f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_gen_len}"
    )

    tokens = torch.full((bsz, total_len), pad_id).cuda()
    for idx, unencoded_seq in enumerate(prompts):
        tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq)

    dim_tok_inserted = [False for _ in range(bsz)]
    input_text_mask = tokens != pad_id
    start_pos = min_prompt_size

    past_kv = None
    cfg_kv = None
    _use_cache = True

    with torch.inference_mode():
        for cur_pos in range(start_pos, total_len):
            # With the kv cache, only the newly generated token needs to be
            # fed on every step after the first.
            token = (
                tokens[:, :start_pos]
                if cur_pos == start_pos
                else tokens[:, cur_pos - 1 : cur_pos]
            )
            logits, past_kv = model.forward(
                token, use_cache=_use_cache, past_kv=past_kv
            )
            logits = logits[:, -1, :]

            if cfg_gamma and max_prompt_size < cur_pos:
                # Second stream without the prompt context; its own cache
                # (cfg_kv) is built up from the generated tokens only.
                uncond_logits, cfg_kv = model.forward(
                    tokens[:, cur_pos - 1 : cur_pos],
                    use_cache=_use_cache,
                    past_kv=cfg_kv,
                )
                uncond_logits = uncond_logits[:, -1, :]
                logits = uncond_logits + cfg_gamma * (logits - uncond_logits)

            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)

            # Only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )

            # Insert dim tokens
            if force_end and cur_pos >= total_len - 130:
                for _idx in range(bsz):
                    if (
                        dim_tok_inserted[_idx] is False
                        and tokenizer.id_to_tok[next_token[_idx].item()][0] != "dur"
                    ):
                        next_token[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok]

            # Update dim_tok_inserted
            for _idx in range(bsz):
                if next_token[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]:
                    dim_tok_inserted[_idx] = True

            tokens[:, cur_pos] = next_token

    decoded = []
    for idx, seq in enumerate(tokens.tolist()):
        # Cut to max gen len
        seq = seq[: len(prompts[idx]) + max_gen_len]
        # Cut to eos tok if any
        try:
            seq = seq[: seq.index(eos_id)]
        except ValueError:
            pass
        decoded.append(tokenizer.decode(seq))

    return decoded
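

# Tiny numeric sketch (illustrative only) of the guidance update above:
# the conditional logits are extrapolated away from the "unconditional"
# (prompt-free) logits by a factor cfg_gamma; gamma=1 recovers the
# conditional logits, gamma>1 amplifies the prompt's influence.
def _demo_cfg_update(cfg_gamma: float = 1.4):
    cond = torch.tensor([2.0, 0.0, -1.0])
    uncond = torch.tensor([1.0, 0.5, -1.0])
    guided = uncond + cfg_gamma * (cond - uncond)
    print(guided)  # tensor([ 2.4000, -0.2000, -1.0000])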


def sample_top_p(probs, p):
    """Nucleus (top-p) sampling: keep the smallest set of tokens whose
    cumulative probability exceeds p, renormalise, and sample from it."""
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Mask out tokens once the cumulative mass (excluding the current one)
    # already exceeds p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map back from sorted order to original token ids.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token