@YouJiacheng
Created November 24, 2024 16:57
rope shift
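
What the script below checks: RoPE applies, per frequency pair, a 2-D rotation by angle position × θ, so the dot product between a rotated query at position m and a rotated key at position n depends only on the offset n − m. Shifting every absolute position by the same constant (here, starting the cache slice at 2048 instead of 16) should therefore leave the attention matrix unchanged up to floating-point error, which is what the printed differences measure. Sketched per frequency pair, with R(α) the 2×2 rotation by angle α:

    \big(R(m\theta)\, q_m\big)^\top \big(R(n\theta)\, k_n\big)
      = q_m^\top R(m\theta)^\top R(n\theta)\, k_n
      = q_m^\top R\big((n-m)\theta\big)\, k_n
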
import torch
import torch.nn as nn
import torch.nn.functional as F


class RoPE(nn.Module):
    def __init__(
        self,
        dim,
        max_seq_len: int = 4096,
        base: int = 10_000,
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        theta = torch.reciprocal(base ** (torch.arange(0, dim, 2)[: dim // 2] / dim))
        assert theta.dtype == torch.float32
        seq_idx = torch.arange(max_seq_len, dtype=theta.dtype, device=theta.device)
        idx_theta = torch.einsum("i,j->ij", seq_idx, theta)
        assert idx_theta.dtype == torch.float32
        self.cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
        assert self.cache.dtype == torch.float32
    def forward(self, x: torch.Tensor, _slice: slice):
        # x: [b, s, n_h, h_d]
        seq_len = x.size(1)
        assert seq_len == _slice.stop - _slice.start
        assert _slice.stop <= self.max_seq_len
        rope_cache = self.cache[_slice]
        # [b, s, n_h, h_d // 2, 2]
        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
        # [1, s, 1, h_d // 2, 2]
        rope_cache = rope_cache.view(-1, seq_len, 1, xshaped.size(3), 2)
        # [b, s, n_h, h_d // 2, 2]
        x_out = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0]
                - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0]
                + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )
        # [b, s, n_h, h_d]
        x_out = x_out.flatten(3)
        return x_out.type_as(x)


def main():
    seq_len = 2048
    d = 128
    m = RoPE(d, max_seq_len=65536)

    def rope(x: torch.Tensor, _slice: slice) -> torch.Tensor:
        return m(x, _slice)

    def sdpa_attn_any(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        # [b, n_h, s, h_d]
        q, k = q.transpose(1, 2), k.transpose(1, 2)
        # with v = identity, SDPA returns the attention-weight matrix itself
        Ev = seq_len
        v = torch.eye(seq_len, dtype=q.dtype).reshape(1, 1, seq_len, Ev)
        # [b, n_h, s, s]; `is_causal` is read from the enclosing scope at call time
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
        return attn.reshape(seq_len, seq_len)

    def sdpa_attn_bf16(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        q, k = q.bfloat16(), k.bfloat16()
        return sdpa_attn_any(q, k)

    q = torch.randn(1, seq_len, 1, d, dtype=torch.float32)
    k = torch.randn(1, seq_len, 1, d, dtype=torch.float32)

    shift = 2048
    rq_s = rope(q, slice(shift, shift + seq_len))
    rk_s = rope(k, slice(shift, shift + seq_len))
    rq_16 = rope(q, slice(16, 16 + seq_len))
    rk_16 = rope(k, slice(16, 16 + seq_len))

    sdpa = sdpa_attn_bf16

    is_causal = True
    a_s = sdpa(rq_s, rk_s)
    a_16 = sdpa(rq_16, rk_16)
    # mean |difference| per key position; the causal mask leaves seq_len - j valid rows in column j
    print((a_s - a_16).abs().sum(dim=0) / torch.arange(seq_len, 0, -1))

    is_causal = False
    a_s = sdpa(rq_s, rk_s)
    a_16 = sdpa(rq_16, rk_16)
    print((a_s - a_16).abs().sum(dim=0) / seq_len)


if __name__ == "__main__":
    main()
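
As a complementary check, the relative-position property can also be verified on raw float32 q·k scores, without SDPA or bf16, by reusing the RoPE module above. A minimal sketch (not part of the gist itself; the helper name and the sizes are illustrative):

def check_relative_property():
    dim, n = 64, 128
    rope_mod = RoPE(dim, max_seq_len=8192)
    q = torch.randn(1, n, 1, dim)
    k = torch.randn(1, n, 1, dim)
    scores = []
    for start in (0, 1000, 4000):  # three different absolute position offsets
        rq = rope_mod(q, slice(start, start + n))
        rk = rope_mod(k, slice(start, start + n))
        # raw attention scores q . k, shape [1, 1, n, n]
        scores.append(torch.einsum("bqhd,bkhd->bhqk", rq, rk))
    # the three score matrices should agree up to float32 rounding error
    print((scores[0] - scores[1]).abs().max(), (scores[0] - scores[2]).abs().max())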