Created
December 5, 2024 00:57
-
-
Save cloneofsimo/1adaec43d08f4af936b02ca151b8e1a1 to your computer and use it in GitHub Desktop.
twodimropevis
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class TwoDimRotary(torch.nn.Module):
    """Precomputed cos/sin tables for 2-D rotary position embeddings.

    At construction a table of angles is built for every cell of an
    ``h`` x ``w`` grid: the first half of the last axis encodes the row
    index and the second half encodes the column index, both using the same
    inverse-frequency schedule ``1 / base ** (2i / dim)``.  The cosine and
    sine of that table are stored as buffers so they move with the module
    across devices and dtypes.

    Args:
        dim: Size of the last axis of the tables (for even ``dim``; each of
            the row and column halves contributes ``len(range(0, dim, 2))``
            channels).
        base: Base of the geometric frequency schedule.
        h: Number of grid rows covered by the tables.
        w: Number of grid columns covered by the tables.
    """

    def __init__(self, dim, base=100, h=128, w=128):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent buffer: follows .to()/.half() like the tables below
        # without adding a new key to the state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.h = h
        self.w = w

        t_h = torch.arange(h).type_as(inv_freq)
        t_w = torch.arange(w).type_as(inv_freq)
        freqs_h = torch.outer(t_h, inv_freq).unsqueeze(1)  # h, 1, d / 2
        freqs_w = torch.outer(t_w, inv_freq).unsqueeze(0)  # 1, w, d / 2
        freqs_h = freqs_h.repeat(1, w, 1)  # h, w, d / 2
        freqs_w = freqs_w.repeat(h, 1, 1)  # h, w, d / 2
        freqs_hw = torch.cat([freqs_h, freqs_w], 2)  # h, w, d
        self.register_buffer("freqs_hw_cos", freqs_hw.cos())
        self.register_buffer("freqs_hw_sin", freqs_hw.sin())

    def forward(self, x, height_width=None, extend_with_register_tokens=0):
        """Return cos/sin tables for a randomly placed crop of the grid.

        Args:
            x: Only consulted when ``height_width`` is None, in which case
                ``x.shape[1]`` is taken as the token count of a square grid
                and the side length is ``int(x.shape[1] ** 0.5)``.
            height_width: Optional ``(height, width)`` of the crop.
            extend_with_register_tokens: Number of all-zero rows to place in
                front of the positional rows, in both cos and sin.

        Returns:
            ``(cos, sin)``, each shaped ``[1, N + T, 1, D]`` where ``N`` is
            ``extend_with_register_tokens``, ``T`` is ``height * width`` and
            ``D`` is the last axis of the precomputed tables.
        """
        if height_width is not None:
            this_h, this_w = height_width
        else:
            side = int(x.shape[1] ** 0.5)
            this_h, this_w = side, side

        # Position augmentation: the crop's top-left corner is drawn uniformly
        # from all offsets that keep the crop inside the h x w grid.  This is
        # sampled on every call.
        start_h = torch.randint(0, self.h - this_h + 1, (1,)).item()
        start_w = torch.randint(0, self.w - this_w + 1, (1,)).item()
        rows = slice(start_h, start_h + this_h)
        cols = slice(start_w, start_w + this_w)

        # clone() guarantees the returned tensors never alias the buffers.
        cos = self.freqs_hw_cos[rows, cols].clone().reshape(this_h * this_w, -1)
        sin = self.freqs_hw_sin[rows, cols].clone().reshape(this_h * this_w, -1)

        # Prepend N zero rows for register tokens.  new_zeros inherits the
        # device and dtype of the tables, so the concatenation stays valid
        # after the module has been moved or cast.
        if extend_with_register_tokens > 0:
            cos_pad = cos.new_zeros(extend_with_register_tokens, cos.shape[1])
            sin_pad = sin.new_zeros(extend_with_register_tokens, sin.shape[1])
            cos = torch.cat([cos_pad, cos], 0)
            sin = torch.cat([sin_pad, sin], 0)

        return cos[None, :, None, :], sin[None, :, None, :]  # [1, N + T, 1, D]
# Demo: build the tables for a 51 x 32 grid with 16 leading register tokens
# and visualise the cosine table as a 3-D point cloud.
h = 51
w = 32
num_register_tokens = 16

two_dim_rotary = TwoDimRotary(dim=32, h=h, w=w)
# The requested crop equals the full grid, so the random offset inside
# forward() can only be 0 and the output is deterministic.  The dummy input
# carries h * w tokens to match the requested grid.
cos, sin = two_dim_rotary(
    torch.randn(1, h * w, 32),
    height_width=(h, w),
    extend_with_register_tokens=num_register_tokens,
)
print(cos.shape, sin.shape)

# Plot cos as a 3-D scatter over (height, width, channel).
import matplotlib.pyplot as plt
import numpy as np
# Kept from the original script; older Matplotlib releases needed this import
# for projection='3d' to be available.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

# Drop the register rows, then reshape (1, seq_len, 1, d) -> (h, w, d).
d = cos.shape[-1]
cos_3d = cos[:, num_register_tokens:, 0, :].reshape(h, w, d)

# Create 3D figure
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# One point per (row, column, channel) cell, coloured by its cosine value.
# indexing='ij' keeps the coordinate order aligned with cos_3d's axes.
scatter = ax.scatter(
    *np.meshgrid(range(h), range(w), range(d), indexing='ij'),
    c=cos_3d.reshape(-1).numpy(),
    cmap='RdBu',
    alpha=1.0,
)

# Add color bar
plt.colorbar(scatter, ax=ax)

# Set labels and title
ax.set_xlabel('Height')
ax.set_ylabel('Width')
ax.set_zlabel('Dimension')
ax.set_title('2D Rotary Embeddings Cosine Values')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment