Skip to content

Instantly share code, notes, and snippets.

@cloneofsimo
Created December 5, 2024 00:57
Show Gist options
  • Save cloneofsimo/1adaec43d08f4af936b02ca151b8e1a1 to your computer and use it in GitHub Desktop.
Save cloneofsimo/1adaec43d08f4af936b02ca151b8e1a1 to your computer and use it in GitHub Desktop.
twodimropevis
import torch
class TwoDimRotary(torch.nn.Module):
    """Precomputed cos/sin tables for 2-D (height, width) rotary embeddings.

    For every cell of an ``h x w`` grid the angle vector is the concatenation
    of the row-position angles and the column-position angles, so the last
    axis of each table has ``2 * len(inv_freq)`` entries (``dim`` for an even
    ``dim``). The tables are stored as the buffers ``freqs_hw_cos`` and
    ``freqs_hw_sin``, each shaped ``(h, w, dim)``.

    Args:
        dim: Size of the rotary feature axis; frequencies are generated from
            ``torch.arange(0, dim, 2)``.
        base: Base of the geometric frequency progression.
        h: Number of grid rows to precompute.
        w: Number of grid columns to precompute.
    """

    def __init__(self, dim, base=100, h=128, w=128):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / (dim)))
        self.h = h
        self.w = w

        t_h = torch.arange(h).type_as(self.inv_freq)
        t_w = torch.arange(w).type_as(self.inv_freq)
        freqs_h = torch.outer(t_h, self.inv_freq).unsqueeze(1)  # h, 1, d / 2
        freqs_w = torch.outer(t_w, self.inv_freq).unsqueeze(0)  # 1, w, d / 2
        # Broadcast both halves over the full grid. `expand` creates views;
        # the `cat` below materialises the single (h, w, d) result.
        freqs_h = freqs_h.expand(h, w, -1)  # h, w, d / 2
        freqs_w = freqs_w.expand(h, w, -1)  # h, w, d / 2
        freqs_hw = torch.cat([freqs_h, freqs_w], 2)  # h, w, d

        self.register_buffer("freqs_hw_cos", freqs_hw.cos())
        self.register_buffer("freqs_hw_sin", freqs_hw.sin())

    def forward(self, x, height_width=None, extend_with_register_tokens=0):
        """Return ``(cos, sin)`` tables for a randomly placed grid window.

        Args:
            x: Tensor whose ``shape[1]`` is the token count. It is only read
                when ``height_width`` is None.
            height_width: Optional ``(height, width)`` of the token grid. When
                omitted, a square grid with side ``int(x.shape[1] ** 0.5)`` is
                used.
            extend_with_register_tokens: Number ``N`` of all-zero rows placed
                in front of the grid rows.

        Returns:
            Two tensors shaped ``[1, N + height * width, 1, d]``. The first
            ``N`` rows are zeros; the remaining rows are the row-major
            flattened window of the precomputed tables.
        """
        if height_width is not None:
            this_h, this_w = height_width
        else:
            this_hw = x.shape[1]
            this_h = this_w = int(this_hw ** 0.5)

        # Random window offset inside the precomputed grid (position
        # augmentation). `torch.randint` raises if the requested window is
        # larger than the precomputed grid.
        # NOTE(review): the offset is drawn on every call, including eval
        # mode - confirm that is intended.
        start_h = torch.randint(0, self.h - this_h + 1, (1,)).item()
        start_w = torch.randint(0, self.w - this_w + 1, (1,)).item()

        rows = slice(start_h, start_h + this_h)
        cols = slice(start_w, start_w + this_w)
        # `clone` keeps the returned tensors from aliasing the buffers.
        cos = self.freqs_hw_cos[rows, cols].clone().reshape(this_h * this_w, -1)
        sin = self.freqs_hw_sin[rows, cols].clone().reshape(this_h * this_w, -1)

        if extend_with_register_tokens > 0:
            # Prepend N zero rows for the register tokens. `new_zeros` creates
            # the padding on the same device and dtype as the tables, so this
            # keeps working after the module is moved with `.to(...)`.
            # NOTE(review): cos == 0 and sin == 0 is not an identity rotation;
            # presumably deliberate ("zero-attn" tokens) - confirm.
            cos_pad = cos.new_zeros(extend_with_register_tokens, cos.shape[1])
            sin_pad = sin.new_zeros(extend_with_register_tokens, sin.shape[1])
            cos = torch.cat([cos_pad, cos], 0)
            sin = torch.cat([sin_pad, sin], 0)

        return cos[None, :, None, :], sin[None, :, None, :]  # [1, N + T, 1, Attn-dim]
# Demo: build tables for a 51 x 32 grid and request the full grid plus
# 16 register-token rows in front of it.
two_dim_rotary = TwoDimRotary(dim=32, h=51, w=32)
# The dummy input carries one token per grid cell (51 * 32) so that it agrees
# with the explicit height_width; its values are never read.
cos, sin = two_dim_rotary(
    torch.randn(1, 51 * 32, 32),
    height_width=(51, 32),
    extend_with_register_tokens=16,
)
print(cos.shape, sin.shape)
# plot cos in 2d
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
# Grid geometry that was used to generate `cos` above.
h, w = 51, 32
d = cos.shape[-1]

# Skip the 16 leading register-token rows and drop the singleton axes,
# turning (1, seq_len, 1, d) into an (h, w, d) volume.
cos_3d = cos[:, 16:, 0, :].reshape(h, w, d)

# One 3-D axes on a square canvas.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# Coordinates of every (row, column, channel) cell; 'ij' indexing keeps them
# in the same order as the flattened cos_3d values.
grid_h, grid_w, grid_d = np.meshgrid(range(h), range(w), range(d), indexing='ij')

# One point per cell, coloured by its cosine value.
scatter = ax.scatter(
    grid_h,
    grid_w,
    grid_d,
    c=cos_3d.reshape(-1),
    cmap='RdBu',
    alpha=1.0
)
plt.colorbar(scatter, ax=ax)

# Axis labels and title.
ax.set_xlabel('Height')
ax.set_ylabel('Width')
ax.set_zlabel('Dimension')
plt.title('2D Rotary Embeddings Cosine Values')

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment