Skip to content

Instantly share code, notes, and snippets.

@attentionmech
Last active May 21, 2025 16:29
Show Gist options
  • Save attentionmech/c9edf9a79b3aca74dc5c2aef8cb7d2a0 to your computer and use it in GitHub Desktop.
Save attentionmech/c9edf9a79b3aca74dc5c2aef8cb7d2a0 to your computer and use it in GitHub Desktop.
activation dynamics vis
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from sklearn.decomposition import PCA
# === CONFIG ===
prompt = "1 2 3 4 5 "
num_generate = 200  # generation steps (= animation frames); adjust if memory constrained
alpha_min = 0.05  # opacity given to the oldest segment of the trail
# Pick the best available backend instead of hard-coding "mps", which fails on
# machines without an Apple-silicon GPU. MPS is tried first so the original
# behaviour on Macs is unchanged. Assign a string here to force a device.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# === Load GPT-2: tokenizer plus the LM-head model, moved to `device` and put
# in inference mode.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", output_hidden_states=True)
model = model.to(device)
model.eval()  # disable dropout so repeated forward passes are deterministic
# === Generate + sync hidden states + text
# Each generation step records, for the newest position, one hidden-state
# vector per transformer layer (hidden_states[0], the embedding output, is
# skipped), plus the text decoded so far. Both lists therefore stay aligned
# step-for-step: step i contributes one text and n_layer vectors.
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
generated_ids = input_ids.clone()
hidden_vectors = []
frame_texts = []

# GPT-2 has a fixed number of learned positions. The longest input fed to the
# model is the prompt plus every sampled token except the last one.
max_positions = model.config.n_positions
longest_input = input_ids.shape[1] + num_generate - 1
if longest_input > max_positions:
    raise ValueError(
        f"prompt ({input_ids.shape[1]} tokens) + num_generate ({num_generate}) "
        f"exceeds GPT-2's context window of {max_positions} positions"
    )

with torch.no_grad():
    # KV cache: after the first (full-prompt) pass, each step only feeds the
    # newly sampled token, avoiding an O(n^2) re-encode of the whole sequence.
    past_key_values = None
    step_ids = generated_ids
    for _ in range(num_generate):
        outputs = model(
            step_ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_hidden_states=True,
        )
        past_key_values = outputs.past_key_values
        # Collect the per-layer vectors for the current token and copy them to
        # the host in a single transfer; extend() appends one row per layer.
        layer_vectors = torch.stack(
            [layer_hidden[0, -1, :] for layer_hidden in outputs.hidden_states[1:]]
        )
        hidden_vectors.extend(layer_vectors.cpu().numpy())
        # Text that is on screen while this step's vectors are drawn.
        frame_texts.append(tokenizer.decode(generated_ids[0]))
        # Sample the next token from the full softmax distribution.
        logits = outputs.logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_ids = torch.cat([generated_ids, next_token], dim=1)
        step_ids = next_token
# === Project to 3D
# Rows are ordered step-major (all layer vectors of step 0, then step 1, ...).
# One PCA is fit over every row, so all frames share a single 3D basis.
hidden_matrix = np.stack(hidden_vectors, axis=0)
pca = PCA(n_components=3)
points_3d = pca.fit_transform(hidden_matrix)
# === Build animation frames
# Frame i holds: every projected point recorded up to and including step i,
# one opacity per segment between consecutive points, and the text at step i.
frames = []
cmap = colormaps["plasma"]
# Derive the number of layer vectors per step from the data instead of
# hard-coding 12, so larger GPT-2 variants (24/36/48 layers) also work.
layers_per_step = len(points_3d) // num_generate if num_generate else 0
for i in range(num_generate):
    end = (i + 1) * layers_per_step
    segment = points_3d[:end]
    # Linear ramp: alpha_min for the oldest segment, rising toward 1 for the
    # newest (same arithmetic as alpha_min + (1 - alpha_min) * (j / end)).
    alphas = alpha_min + (1 - alpha_min) * (np.arange(end - 1) / end)
    frames.append((segment, alphas, frame_texts[i]))
# === Plot & Animate
# A single 3D axes; subplots_adjust(top=0.80) keeps the top 20% of the figure
# height free of the axes.
fig = plt.figure(figsize=(30, 20))
fig.subplots_adjust(top=0.80)
ax = fig.add_subplot(1, 1, 1, projection="3d")
# Caption artist in axes-fraction coordinates (horizontally centred, 60% up);
# its text is replaced on every animation frame.
text_handle = ax.text2D(
    0.5,
    0.60,
    "",
    transform=ax.transAxes,
    ha="center",
    va="top",
    fontsize=14,
    color="white",
    wrap=True,
)
def update(frame_idx):
    """Redraw the axes for animation frame ``frame_idx``.

    Clears the axes and draws the trail stored in ``frames[frame_idx]`` as a
    single line collection. The trail has one coloured, semi-transparent
    segment between each pair of consecutive points. The caption is then
    updated with the tail of the generated text.

    Returns an empty list: the animation runs with ``blit=False``, so no
    artists have to be reported for blitting.
    """
    ax.cla()
    ax.set_facecolor("#000000")
    ax.axis('off')
    # Pin the view box to the extent of the whole trajectory. Otherwise the
    # axes autoscale to the partial trail each frame and the camera jitters.
    lo = points_3d.min(axis=0)
    hi = points_3d.max(axis=0)
    ax.set_xlim(lo[0], hi[0])
    ax.set_ylim(lo[1], hi[1])
    ax.set_zlim(lo[2], hi[2])

    segment, alphas, current_text = frames[frame_idx]
    n_segments = len(segment) - 1
    if n_segments > 0:
        # (n_segments, 2, 3): start and end point of every segment.
        pairs = np.stack([segment[:-1], segment[1:]], axis=1)
        # Colour by global index so a segment keeps its colour across frames;
        # the per-segment opacity goes into the RGBA alpha channel.
        colors = cmap(np.arange(n_segments) / len(points_3d))
        colors[:, 3] = alphas
        # One collection replaces thousands of separate ax.plot calls per frame.
        ax.add_collection3d(Line3DCollection(pairs, colors=colors, linewidths=2))

    # Show the last 10 tokens. The decoded text is re-tokenized here, so the
    # cut may not fall exactly where the sampled tokens did.
    last_tokens = tokenizer.decode(tokenizer.encode(current_text)[-10:])
    text_handle.set_text(last_tokens)
    # ax.cla() detached the caption from the axes; attach it again.
    ax.add_artist(text_handle)
    return []
# One animation frame per generation step, 100 ms apart. blit=False because
# update() clears and redraws the whole axes each frame.
ani = FuncAnimation(
    fig,
    update,
    frames=len(frames),
    interval=100,
    blit=False,
    repeat=True,
)
# === Save or show
ani.save("gpt2_synced_trail.mp4", writer="ffmpeg", dpi=200)
# To preview interactively instead of writing a file, use this in place of ani.save:
# plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment