@KohakuBlueleaf
Created July 20, 2023 09:36
A simple implementation of retention (from https://arxiv.org/pdf/2307.08621.pdf)
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
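# Three mathematically equivalent ways to compute retention (see the paper linked above):
#   parallel_retention  -- full seq_len x seq_len decay-masked form (used during training)
#   recurrent_retention -- O(1)-state update per token (used for step-by-step inference)
#   chunked_retention   -- parallel inside each chunk, recurrent state across chunks
# retention() below dispatches between the three based on its chunk_size argument.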
def parallel_retention(
    q, k, v,          # bsz, heads, seq_len, dim
    decay_mask=None,  # heads, seq_len, seq_len
):
    retention = q @ k.transpose(-1, -2)
    retention = retention * decay_mask
    retention = retention @ v
    return retention
def recurrent_retention(
    q, k, v,  # bsz, heads, dim
    past_kv,  # bsz, heads, dim, dim
    decay,    # heads, 1, 1
):
    current_kv = decay * past_kv + k.unsqueeze(-1) * v.unsqueeze(-2)
    output = (q.unsqueeze(-2) @ current_kv).squeeze(-2)
    return output, current_kv
def chunked_retention(
    q, k, v,        # bsz, heads, chunk_size, dim
    past_kv,        # bsz, heads, dim, dim
    decay_mask,     # heads, chunk_size, chunk_size
    chunk_decay,    # heads, 1, 1
    inner_decay,    # heads, chunk_size, 1
    current_decay,  # heads, chunk_size, 1
):
    retention = q @ k.transpose(-1, -2)
    retention = retention * decay_mask
    inner_retention = retention @ v
    cross_retention = (q @ past_kv) * inner_decay
    output = inner_retention + cross_retention
    current_kv = chunk_decay * past_kv + (k * current_decay).transpose(-1, -2) @ v
    return output, current_kv
def decay_grid(seq_len, decay, device=None):
    i, j = torch.meshgrid(
        torch.arange(seq_len, device=device),
        torch.arange(seq_len, device=device),
        indexing="ij",
    )
    # lower-triangular decay mask: decay**(i-j) for i >= j, zero elsewhere
    return decay**(i - j) * (i - j >= 0)
def retention(q, k, v, decay, chunk_size=None):
    is_parallel = chunk_size is None
    is_recurrent = chunk_size is not None and chunk_size == 1
    is_chunked = chunk_size is not None and chunk_size > 1

    if not isinstance(decay, torch.Tensor):
        decay = torch.tensor(decay, device=q.device)
    if decay.dim() <= 1:
        decay = decay.unsqueeze(-1).unsqueeze(-1)
    if decay.size(0) == 1:
        decay = decay.repeat(q.size(1), 1, 1)

    b, h, seq, dim = q.size()
    v_dim = v.size(-1)

    if is_parallel:
        decay_mask = decay_grid(seq, decay, device=q.device)
        output = parallel_retention(q, k, v, decay_mask)
    elif is_recurrent:
        outputs = []
        past_kv = 0
        for idx in range(seq):
            out, past_kv = recurrent_retention(
                q[..., idx, :],
                k[..., idx, :],
                v[..., idx, :],
                past_kv,
                decay,
            )
            outputs.append(out.unsqueeze(-2))
        output = torch.concatenate(outputs, dim=-2)
    elif is_chunked:
        outputs = []
        past_kv = torch.zeros(b, h, dim, v_dim, device=k.device)
        # decay mask only spans a single chunk, not the full sequence
        decay_mask = decay_grid(chunk_size, decay, device=q.device)
        chunk_decay = decay**chunk_size
        current_decay = decay**torch.arange(chunk_size, device=q.device).unsqueeze(-1)
        inner_decay = decay * current_decay
        for idx in range(seq // chunk_size):
            out, past_kv = chunked_retention(
                q[..., idx*chunk_size:(idx+1)*chunk_size, :],
                k[..., idx*chunk_size:(idx+1)*chunk_size, :],
                v[..., idx*chunk_size:(idx+1)*chunk_size, :],
                past_kv,
                decay_mask,
                chunk_decay,
                inner_decay,
                current_decay.flip(-2),
            )
            outputs.append(out)
        output = torch.concatenate(outputs, dim=-2)
    else:
        raise ValueError("chunk_size must be None or a positive integer")
    return output
class Retention(nn.Module):
    def __init__(self, dim=512, heads=8, gammas=0.995, gated=False):
        super().__init__()
        self.dim = dim
        self.heads = heads
        if not isinstance(gammas, torch.Tensor):
            gammas = torch.tensor(gammas)
        gammas = gammas.view(-1)
        if gammas.size(0) == 1:
            gammas = gammas.repeat(self.heads)
        else:
            assert gammas.size(0) == self.heads
        # keep the per-head decays as a buffer so they follow .to()/.cuda()
        self.register_buffer("gammas", gammas.unsqueeze(-1).unsqueeze(-1))
        self.gated = gated
        if gated:
            self.x_proj = nn.Linear(dim, dim)
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.out_norm = nn.GroupNorm(self.heads, dim)
        self.to_out = nn.Linear(dim, dim)
    def forward(self, x):
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)
        q = rearrange(q, "b l (h d) -> b h l d", h=self.heads)
        k = rearrange(k, "b l (h d) -> b h l d", h=self.heads)
        v = rearrange(v, "b l (h d) -> b h l d", h=self.heads)
        ret = retention(q, k, v, self.gammas)
        # GroupNorm(heads, dim) normalizes each head separately: fold heads into channels first
        ret = rearrange(ret, "b h l d -> b (h d) l")
        ret = self.out_norm(ret)
        ret = rearrange(ret, "b (h d) l -> b l (h d)", h=self.heads)
        if self.gated:
            ret = ret * F.silu(self.x_proj(x))
        return self.to_out(ret)
if __name__ == '__main__':
    # check consistency between the parallel, recurrent and chunked forms
    test_q = torch.randn(1, 40, 64*6, 128).cuda()
    test_k = torch.randn(1, 40, 64*6, 128).cuda()
    test_v = torch.randn(1, 40, 64*6, 128).cuda()
    decay = 1 - 2**(-5 - torch.arange(40, device=test_q.device).float()/10)

    test_recurrent = retention(test_q, test_k, test_v, decay, 1)
    test_parallel = retention(test_q, test_k, test_v, decay)
    test_chunked = retention(test_q, test_k, test_v, decay, 64)

    print(torch.max(F.mse_loss(test_recurrent, test_parallel, reduction="none")))
    print(torch.max(F.mse_loss(test_chunked, test_parallel, reduction="none")))
    print(torch.max(F.mse_loss(test_chunked, test_recurrent, reduction="none")))
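A minimal usage sketch (not from the original gist; the shapes, hyperparameters, and decay values below are illustrative assumptions):

# Module API: project, retain, normalize, (optionally) gate, project out.
model = Retention(dim=512, heads=8, gammas=0.995, gated=True)
x = torch.randn(2, 128, 512)                          # bsz, seq_len, dim
y = model(x)                                          # -> (2, 128, 512)

# Functional API with one decay per head, analogous to the test above.
gamma = 1 - 2**(-5 - torch.arange(8).float() / 10)    # illustrative per-head decays
q = k = v = torch.randn(2, 8, 128, 64)                # bsz, heads, seq_len, head_dim
out_parallel  = retention(q, k, v, gamma)                 # chunk_size=None -> parallel form
out_recurrent = retention(q, k, v, gamma, chunk_size=1)   # token-by-token recurrence
out_chunked   = retention(q, k, v, gamma, chunk_size=32)  # chunkwise recurrence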
@KohakuBlueleaf (Author) commented:
I use a different equation for chunkwise retention to ensure its consistency with the parallel and recurrent forms.
If you are interested in the details, check this:
https://hackmd.io/@KBlueLeaf/rk442m852
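For reference, the chunkwise form as implemented in chunked_retention above can be read off the code as follows (a sketch, with B the chunk size, \gamma the per-head decay, n, m zero-based positions inside chunk i, and R_{i-1} the carried state past_kv; see the linked note for the author's full derivation):

    \mathrm{inner}_i = (Q_i K_i^{\top} \odot D)\, V_i, \qquad D_{nm} = \gamma^{\,n-m}\,[n \ge m]
    \mathrm{cross}_i = (Q_i R_{i-1}) \odot \xi, \qquad \xi_n = \gamma^{\,n+1}
    O_i = \mathrm{inner}_i + \mathrm{cross}_i, \qquad R_i = \gamma^{B} R_{i-1} + (K_i \odot \zeta)^{\top} V_i, \qquad \zeta_n = \gamma^{\,B-1-n}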
