# Qwen-2.5-0.5B-stripes
# GitHub gist av/0bbc2ca24d8ad13c58da1390ef2a7d51 by @av, created February 14, 2025 19:59.
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM
import gc
# Checkpoint whose weight matrices are visualized.
# Alternative checkpoint tried by the author: "Qwen/Qwen2.5-0.5B-Instruct".
MODEL_ID = "kz919/QwQ-0.5B-Distilled-SFT"

# Only parameters with exactly this shape are plotted.
TARGET_SHAPE = (896, 896)
# Maximum number of matching parameters stacked into the figure.
MAX_LAYERS = 3
# Number of histogram bins used to locate the median bin of each matrix.
HIST_BINS = 64
# Vertical distance added between consecutive plotted layers.
Z_STEP = 0.2


def median_bin_mask(data, bins=HIST_BINS):
    """Return a boolean mask of the entries of *data* in its median histogram bin.

    The values are histogrammed into *bins* equal-width bins; the median bin
    is the first bin at which the cumulative count reaches half of the total
    number of values.

    Args:
        data: NumPy array of numeric values (any shape).
        bins: Number of equal-width histogram bins.

    Returns:
        Boolean array with the same shape as *data*; True where the value
        falls inside the median bin.
    """
    hist, bin_edges = np.histogram(data, bins=bins)
    cumulative_hist = np.cumsum(hist)
    total_values = cumulative_hist[-1]
    median_bin = int(np.searchsorted(cumulative_hist, total_values // 2))
    lower = bin_edges[median_bin]
    upper = bin_edges[median_bin + 1]
    if median_bin == len(hist) - 1:
        # np.histogram closes the last bin on the right, so the maximum
        # value belongs to it; a strict "<" here would silently drop it.
        return (data >= lower) & (data <= upper)
    return (data >= lower) & (data < upper)


def main():
    """Plot the median-bin entries of the first matching weight matrices in 3D.

    Loads MODEL_ID, walks its parameters in order, and for up to MAX_LAYERS
    parameters of shape TARGET_SHAPE scatters the entries selected by
    ``median_bin_mask`` at (row, column, value + offset), stacking each layer
    Z_STEP above the previous one.

    Raises:
        RuntimeError: If the model has no parameter of shape TARGET_SHAPE.
    """
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(111, projection='3d')
    cmap = plt.cm.viridis

    z_offset = 0.0
    layer_count = 0
    scatter = None

    for _name, param in model.named_parameters():
        if layer_count >= MAX_LAYERS:
            # Enough layers plotted; no need to scan the remaining parameters.
            break
        if tuple(param.shape) != TARGET_SHAPE:
            continue

        # Upcast to float32 first: NumPy conversion fails for reduced
        # precision tensors such as bfloat16.
        data = param.detach().float().cpu().numpy()
        mask = median_bin_mask(data, bins=HIST_BINS)
        valid_points = np.where(mask)

        scatter = ax.scatter(
            valid_points[0],
            valid_points[1],
            data[valid_points] + z_offset,
            c=data[valid_points],
            cmap=cmap,
            alpha=0.1,
            s=0.2,
        )
        z_offset += Z_STEP
        layer_count += 1

        # Clean up the per-layer arrays before the next matrix is processed.
        del data, mask, valid_points
        gc.collect()

    if scatter is None:
        # Previously this surfaced as a NameError on the colorbar call.
        raise RuntimeError(
            f'No parameter of shape {TARGET_SHAPE} found in {MODEL_ID}'
        )

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Values + Offset')
    ax.view_init(elev=22.)
    # NOTE: the colorbar reflects the colour scale of the last plotted layer;
    # each layer is normalized to its own value range.
    fig.colorbar(scatter, ax=ax)
    plt.title(
        f'Stacked visualization of {layer_count} layers '
        f'({TARGET_SHAPE[0]}x{TARGET_SHAPE[1]})'
    )
    plt.show()


if __name__ == "__main__":
    main()
# (GitHub page footer — comment sign-in prompt, not part of the script.)