FPTQuant Query-Key Transform Fusion
import torch
from torch import nn
"""
Proof-of-concept that you can fuse a transformation matrix
into q_proj and k_proj without changing model outputs.
As described in FPTQuant: https://arxiv.org/pdf/2506.04985
"""
torch.random.manual_seed(42)
torch.set_grad_enabled(False)
seq = 32
dim = 768
x = torch.randn(seq, dim)
# RoPE matrix as in RoFormer Section 3.2.2
# ...I know I know, don't reimplement RoPE, this is very very close to
# the torchtune implementation, which you can swap in if you don't believe me:
# from torchtune.modules import RotaryPositionalEmbeddings
# tt_rope = RotaryPositionalEmbeddings(dim)
# def RoPE(x):
#     return tt_rope(x.unsqueeze(0).unsqueeze(-2)).squeeze()
thetas = torch.tensor(10_000).pow(-(torch.arange(0, dim, 2).float() ) / dim)
phi = torch.arange(seq).unsqueeze(1).float() * thetas.unsqueeze(0)
coses, sines = torch.cos(phi), torch.sin(phi)
full_rope_matrix = torch.stack([
    torch.block_diag(*[torch.tensor([[coses[s,i], -sines[s,i]],[sines[s,i], coses[s,i]]])
                       for i in range(dim//2)])
    for s in range(seq)
])
def RoPE(x):
    return (full_rope_matrix @ x.unsqueeze(-1)).squeeze()
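# Optional: cross-check the hand-rolled RoPE against the torchtune
# implementation referenced above (skipped if torchtune isn't installed).
try:
    from torchtune.modules import RotaryPositionalEmbeddings
    tt_rope = RotaryPositionalEmbeddings(dim)
    tt_out = tt_rope(x.unsqueeze(0).unsqueeze(-2)).squeeze()
    print("Hand-rolled RoPE vs torchtune max diff:", (RoPE(x) - tt_out).abs().max())
except ImportError:
    pass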
# Original Model (Single Head)
q_proj = torch.nn.Linear(dim, dim)
k_proj = torch.nn.Linear(dim, dim)
out_original = RoPE(q_proj(x)) @ RoPE(k_proj(x)).T
# Transformation matrices from https://arxiv.org/pdf/2506.04985 Section 3.1.1
# NOTE: HF computes RoPE differently (rotate_half over split halves rather than
# interleaved pairs), so this won't work as a drop-in there; see the sketch at
# the end of this file.
# See: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/3
scales = torch.randn(dim//2).repeat_interleave(2) # can be anything nonzero; repeated so each 2x2 RoPE block shares one scale
thetas = torch.randn(dim//2).remainder(2*torch.pi) # can be any angles
coses, sines = torch.cos(thetas), torch.sin(thetas)
R = torch.block_diag(*[torch.tensor([[coses[i], -sines[i]],[sines[i], coses[i]]]) for i in range(thetas.shape[0])])
Tk = scales.diag() @ R
Tq = scales.reciprocal().diag() @ R
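# Sanity check: R is orthogonal and the scales cancel, so Tq.T @ Tk is the
# identity (up to float error). Repeating each scale across its 2x2 block is
# what also lets Tq and Tk commute with the RoPE rotations, so Q @ K.T is
# preserved.
assert torch.allclose(Tq.T @ Tk, torch.eye(dim), atol=1e-4)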
# Transformed Model
q_proj.weight.data = Tq @ q_proj.weight.data
q_proj.bias.data = Tq @ q_proj.bias.data
k_proj.weight.data = Tk @ k_proj.weight.data
k_proj.bias.data = Tk @ k_proj.bias.data
out_transformed = RoPE(q_proj(x)) @ RoPE(k_proj(x)).T
# Compare
print("Original [email protected]:\n", out_original)
print("Transformed:\n",out_transformed)
print("Max Diff:", (out_original - out_transformed).abs().max())