Activation dynamics visualization: GPT-2 hidden-state trajectories projected to 3D with PCA and animated as tokens are generated.
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the 3D projection on older matplotlib)
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from sklearn.decomposition import PCA
# === CONFIG ===
prompt = "1 2 3 4 5 "
num_generate = 200  # adjust if memory constrained
alpha_min = 0.05
device = "mps"  # change to "cuda" or "cpu" as needed
# === Load GPT-2
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", output_hidden_states=True).to(device).eval()
# === Generate + sync hidden states + text
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
generated_ids = input_ids.clone()
hidden_vectors = []
frame_texts = []
with torch.no_grad():
    for _ in range(num_generate):
        outputs = model(generated_ids, output_hidden_states=True)
        # collect all 12 layer vectors for the current token
        for layer_hidden in outputs.hidden_states[1:]:
            last_vector = layer_hidden[0, -1, :]
            hidden_vectors.append(last_vector.cpu().numpy())
        # update the current text shown on screen
        frame_texts.append(tokenizer.decode(generated_ids[0]))
        # generate next token (sampling)
        logits = outputs.logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_ids = torch.cat([generated_ids, next_token], dim=1)
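# Sanity check (added, not in the original gist): outputs.hidden_states[0] is the
# embedding output, which the loop above skips, so GPT-2 small's 12 transformer
# blocks contribute exactly 12 vectors per generated step; that is the chunk size
# the frame builder below relies on.
assert len(hidden_vectors) == num_generate * 12, "expected 12 hidden-state vectors per step"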
# === Project to 3D
hidden_matrix = np.stack(hidden_vectors)
points_3d = PCA(n_components=3).fit_transform(hidden_matrix)
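# Optional diagnostic (added, not in the original gist): report how much of the
# hidden-state variance the three principal components retain, to gauge how
# faithful the 3D trail is to the 768-dimensional trajectory.
pca_check = PCA(n_components=3).fit(hidden_matrix)
print(f"Variance captured by the 3D projection: {pca_check.explained_variance_ratio_.sum():.1%}")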
# === Build animation frames
frames = []
cmap = colormaps["plasma"]
for i in range(num_generate):
    end = (i + 1) * 12
    segment = points_3d[:end]
    alphas = [alpha_min + (1 - alpha_min) * (j / end) for j in range(end - 1)]
    frames.append((segment, alphas, frame_texts[i]))
# === Plot & Animate | |
fig = plt.figure(figsize=(30, 20)) | |
fig.subplots_adjust(top=0.80) # Leave space for text | |
ax = fig.add_subplot(111, projection='3d') | |
text_handle = ax.text2D( | |
0.5, 0.60, "", transform=ax.transAxes, | |
ha="center", va="top", fontsize=14, color='white', wrap=True | |
) | |
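# (Added note, not part of the original gist) The caption is drawn in white while
# only the axes background is set to black inside update(); if the saved frames
# come out on a white figure background, the text may be unreadable. Matching the
# figure background is one possible fix:
# fig.patch.set_facecolor("#000000")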
def update(frame_idx):
    ax.cla()
    ax.set_facecolor("#000000")
    ax.axis('off')
    segment, alphas, current_text = frames[frame_idx]
    for i in range(len(segment) - 1):
        x = [segment[i, 0], segment[i + 1, 0]]
        y = [segment[i, 1], segment[i + 1, 1]]
        z = [segment[i, 2], segment[i + 1, 2]]
        ax.plot(x, y, z, color=cmap(i / len(points_3d)), linewidth=2, alpha=alphas[i])
    last_tokens = tokenizer.decode(tokenizer.encode(current_text)[-10:])
    text_handle.set_text(last_tokens)
    ax.add_artist(text_handle)
    return []

ani = FuncAnimation(fig, update, frames=len(frames), blit=False, interval=100, repeat=True)
# === Save or show
ani.save("gpt2_synced_trail.mp4", writer="ffmpeg", dpi=200)
# plt.show()  # Uncomment to preview live instead of saving
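# (Added alternative, not part of the original gist) If ffmpeg is not installed,
# matplotlib's built-in PillowWriter can write an animated GIF instead; the
# filename and fps here are arbitrary choices:
# from matplotlib.animation import PillowWriter
# ani.save("gpt2_synced_trail.gif", writer=PillowWriter(fps=10), dpi=100)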