A simple implementation of retention (from https://arxiv.org/pdf/2307.08621.pdf)
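A minimal usage sketch of the retention helper defined below (the shapes and decay schedule here are illustrative, not taken from the gist; chunk_size selects between the parallel, recurrent, and chunkwise paths):

import torch

# hypothetical shapes: batch=2, heads=8, seq_len=128, head_dim=64
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)
decay = 1 - 2 ** (-5 - torch.arange(8).float())          # one decay value per head

out_parallel = retention(q, k, v, decay)                  # chunk_size=None -> parallel form
out_recurrent = retention(q, k, v, decay, chunk_size=1)   # token-by-token recurrence
out_chunked = retention(q, k, v, decay, chunk_size=32)    # chunkwise form (seq divisible by 32)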
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


def parallel_retention(
    q, k, v,            # bsz, heads, seq_len, dim
    decay_mask=None,    # heads, seq_len, seq_len
):
    retention = q @ k.transpose(-1, -2)
    retention = retention * decay_mask
    retention = retention @ v
    return retention
def recurrent_retention(
    q, k, v,    # bsz, heads, dim
    past_kv,    # bsz, heads, dim, dim
    decay,      # heads, 1, 1
):
    # S_n = decay * S_{n-1} + k_n^T v_n ; o_n = q_n S_n
    current_kv = decay * past_kv + k.unsqueeze(-1) * v.unsqueeze(-2)
    output = (q.unsqueeze(-2) @ current_kv).squeeze(-2)
    return output, current_kv
def chunked_retention(
    q, k, v,          # bsz, heads, chunk_size, dim
    past_kv,          # bsz, heads, dim, dim
    decay_mask,       # heads, chunk_size, chunk_size
    chunk_decay,      # heads, 1, 1
    inner_decay,      # heads, chunk_size, 1
    current_decay,    # heads, chunk_size, 1
):
    # Within-chunk (parallel) part.
    retention = q @ k.transpose(-1, -2)
    retention = retention * decay_mask
    inner_retention = retention @ v
    # Contribution of the state carried over from previous chunks.
    cross_retention = (q @ past_kv) * inner_decay
    output = inner_retention + cross_retention
    # Update the carried state for the next chunk.
    current_kv = chunk_decay * past_kv + (k * current_decay).transpose(-1, -2) @ v
    return output, current_kv
def decay_grid(seq_len, decay, device=None):
    # Lower-triangular decay mask: D[n, m] = decay ** (n - m) for n >= m, else 0.
    i, j = torch.meshgrid(
        torch.arange(seq_len, device=device),
        torch.arange(seq_len, device=device),
        indexing="ij",
    )
    # Clamp the exponent so decay ** (negative) cannot overflow before it is masked out.
    return decay ** (i - j).clamp_min(0) * (i - j >= 0)
def retention(q, k, v, decay, chunk_size=None):
    is_parallel = chunk_size is None
    is_recurrent = chunk_size is not None and chunk_size == 1
    is_chunked = chunk_size is not None and chunk_size > 1

    if not isinstance(decay, torch.Tensor):
        decay = torch.tensor(decay, device=q.device)
    decay = decay.to(q.device)
    if decay.dim() <= 1:
        decay = decay.unsqueeze(-1).unsqueeze(-1)
    if decay.size(0) == 1:
        decay = decay.repeat(q.size(1), 1, 1)

    b, h, seq, dim = q.size()
    v_dim = v.size(-1)

    if is_parallel:
        decay_mask = decay_grid(seq, decay, device=q.device)
        output = parallel_retention(q, k, v, decay_mask)
    elif is_recurrent:
        outputs = []
        past_kv = 0
        for idx in range(seq):
            out, past_kv = recurrent_retention(
                q[..., idx, :],
                k[..., idx, :],
                v[..., idx, :],
                past_kv,
                decay,
            )
            outputs.append(out.unsqueeze(-2))
        output = torch.cat(outputs, dim=-2)
    elif is_chunked:
        # Assumes seq is divisible by chunk_size (remainder tokens are dropped).
        outputs = []
        past_kv = torch.zeros(b, h, dim, v_dim, device=k.device)
        # The decay mask only covers positions inside a single chunk.
        decay_mask = decay_grid(chunk_size, decay, device=q.device)
        chunk_decay = decay ** chunk_size
        current_decay = decay ** torch.arange(chunk_size, device=q.device).unsqueeze(-1)
        inner_decay = decay * current_decay
        for idx in range(seq // chunk_size):
            out, past_kv = chunked_retention(
                q[..., idx*chunk_size:(idx+1)*chunk_size, :],
                k[..., idx*chunk_size:(idx+1)*chunk_size, :],
                v[..., idx*chunk_size:(idx+1)*chunk_size, :],
                past_kv,
                decay_mask,
                chunk_decay,
                inner_decay,
                current_decay.flip(-2),
            )
            outputs.append(out)
        output = torch.cat(outputs, dim=-2)
    else:
        raise ValueError("chunk_size must be None or a positive integer")
    return output
class Retention(nn.Module):
    def __init__(self, dim=512, heads=8, gammas=0.995, gated=False):
        super().__init__()
        self.dim = dim
        self.heads = heads
        if not isinstance(gammas, torch.Tensor):
            gammas = torch.tensor(gammas)
        gammas = gammas.view(-1)
        if gammas.size(0) == 1:
            gammas = gammas.repeat(self.heads)
        else:
            assert gammas.size(0) == self.heads
        # Register as a buffer so the per-head decays follow the module across devices.
        self.register_buffer("gammas", gammas.unsqueeze(-1).unsqueeze(-1))
        self.gated = gated
        if gated:
            self.x_proj = nn.Linear(dim, dim)
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.out_norm = nn.GroupNorm(self.heads, dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)
        q = rearrange(q, "b l (h d) -> b h l d", h=self.heads)
        k = rearrange(k, "b l (h d) -> b h l d", h=self.heads)
        v = rearrange(v, "b l (h d) -> b h l d", h=self.heads)
        ret = retention(q, k, v, self.gammas)
        # GroupNorm(heads, dim) normalizes each head's channels; it expects (b, dim, l).
        ret = rearrange(ret, "b h l d -> b (h d) l")
        ret = self.out_norm(ret)
        ret = rearrange(ret, "b (h d) l -> b l (h d)", h=self.heads)
        if self.gated:
            ret = ret * F.silu(self.x_proj(x))
        return self.to_out(ret)
if __name__ == '__main__':
    # Check consistency between the parallel, recurrent, and chunked forms.
    test_q = torch.randn(1, 40, 64*6, 128).cuda()
    test_k = torch.randn(1, 40, 64*6, 128).cuda()
    test_v = torch.randn(1, 40, 64*6, 128).cuda()
    decay = 1 - 2**(-5 - torch.arange(40, device=test_q.device).float()/10)

    test_recurrent = retention(test_q, test_k, test_v, decay, 1)
    test_parallel = retention(test_q, test_k, test_v, decay)
    test_chunked = retention(test_q, test_k, test_v, decay, 64)

    # Maximum elementwise squared error between each pair of implementations.
    print(torch.max(F.mse_loss(test_recurrent, test_parallel, reduction="none")))
    print(torch.max(F.mse_loss(test_chunked, test_parallel, reduction="none")))
    print(torch.max(F.mse_loss(test_chunked, test_recurrent, reduction="none")))
I use a different equation for chunkwise retention to ensure its consistency with the parallel and recurrent forms.
If you are interested in the details, check this:
https://hackmd.io/@KBlueLeaf/rk442m852
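For reference, a sketch of the chunkwise recurrence the code above implements (obtained by unrolling the recurrent form; B is the chunk size, n and m index positions inside a chunk, S_i is the past_kv state carried out of chunk i, and gamma is the per-head decay):

$$S_i = \gamma^{B} S_{i-1} + \sum_{m=0}^{B-1} \gamma^{B-1-m}\, k_m^{\top} v_m$$

$$o_n = \sum_{m \le n} \gamma^{n-m}\, q_n k_m^{\top} v_m \;+\; \gamma^{\,n+1}\, q_n S_{i-1}$$

The first term is the inner_retention computed with the within-chunk decay mask, and the second is the cross_retention from the carried state; stepping the recurrent form one token at a time produces the same two terms, which is what the consistency check in __main__ verifies numerically.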