FPTQuant Query-Key Transform Fusion
import torch
from torch import nn

"""
Proof-of-concept that you can fuse a transformation matrix
into q_proj and k_proj without changing model outputs.
As described in FPTQuant: https://arxiv.org/pdf/2506.04985
"""
torch.random.manual_seed(42)
torch.set_grad_enabled(False)

seq = 32
dim = 768
x = torch.randn(seq, dim)

# RoPE matrix as in RoFormer Section 3.2.2
# ...I know I know, don't reimplement RoPE, this is very very close to
# the torchtune implementation which you can swap in if you don't believe me:
# from torchtune.modules import RotaryPositionalEmbeddings
# tt_rope = RotaryPositionalEmbeddings(dim)
# def RoPE(x):
#     return tt_rope(x.unsqueeze(0).unsqueeze(-2)).squeeze()
thetas = torch.tensor(10_000).pow(-(torch.arange(0, dim, 2).float()) / dim)
phi = torch.arange(seq).unsqueeze(1).float() * thetas.unsqueeze(0)
coses, sines = torch.cos(phi), torch.sin(phi)
full_rope_matrix = torch.stack([
    torch.block_diag(*[torch.tensor([[coses[s,i], -sines[s,i]], [sines[s,i], coses[s,i]]])
                       for i in range(dim//2)])
    for s in range(seq)
])

def RoPE(x):
    return (full_rope_matrix @ x.unsqueeze(-1)).squeeze()
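# Note: each full_rope_matrix[s] is block-diagonal, with an independent 2x2
# rotation acting on each adjacent (even, odd) pair of dimensions. The Tq/Tk
# transforms constructed below share exactly this block structure, which is
# what lets them commute with the RoPE rotations.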
# Original Model (Single Head)
q_proj = torch.nn.Linear(dim, dim)
k_proj = torch.nn.Linear(dim, dim)
out_original = RoPE(q_proj(x)) @ RoPE(k_proj(x)).T

# Transformation matrices from https://arxiv.org/pdf/2506.04985 Section 3.1.1
# NOTE: HF computes RoPE differently, this won't work as a drop-in.
# See: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/3
scales = torch.randn(dim//2).repeat_interleave(2)  # can be anything (non-zero)
thetas = torch.randn(dim//2).remainder(2*torch.pi)  # can be anything
coses, sines = torch.cos(thetas), torch.sin(thetas)
R = torch.block_diag(*[torch.tensor([[coses[i], -sines[i]], [sines[i], coses[i]]]) for i in range(thetas.shape[0])])
Tk = scales.diag() @ R
Tq = scales.reciprocal().diag() @ R
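# Optional sanity checks for the two properties that make the fusion exact
# (writing S = scales.diag(), so Tk = S @ R and Tq = S^-1 @ R):
#   1. Tq.T @ Tk = R.T @ S^-1 @ S @ R = I, so the pair cancels inside Q @ K^T.
#   2. Tq and Tk commute with every per-position RoPE matrix, since all of them
#      are block-diagonal with a (scaled) 2x2 rotation acting on the same planes.
# Tolerances are loose because `scales` is unconstrained random and its
# reciprocal can be large.
assert torch.allclose(Tq.T @ Tk, torch.eye(dim), atol=1e-3)
assert torch.allclose(full_rope_matrix[5] @ Tq, Tq @ full_rope_matrix[5], rtol=1e-3, atol=1e-3)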
# Transformed Model
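# Folding a transform into the linear layer uses the identity
#   T @ (W @ x + b) == (T @ W) @ x + (T @ b),
# so rewriting the weights and biases below is equivalent to multiplying the
# q_proj / k_proj outputs by Tq / Tk at runtime.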
q_proj.weight.data = Tq @ q_proj.weight.data
q_proj.bias.data = Tq @ q_proj.bias.data
k_proj.weight.data = Tk @ k_proj.weight.data
k_proj.bias.data = Tk @ k_proj.bias.data
out_transformed = RoPE(q_proj(x)) @ RoPE(k_proj(x)).T

# Compare
print("Original [email protected]:\n", out_original) | |
print("Transformed:\n",out_transformed) | |
print("Max Diff:", (out_original - out_transformed).abs().max()) |