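# Assumed setup (not part of the original gist): `model`, `device`, and `to_numpy`
# are expected to already be defined in the surrounding notebook -- a pretrained
# GPT-2-small-sized transformer (d_model = 768, d_head = 64) loaded with
# EasyTransformer/TransformerLens, a torch device, and a tensor-to-NumPy helper,
# roughly along the lines of:
#
#   device = "cuda" if t.cuda.is_available() else "cpu"
#   def to_numpy(x): return x.detach().cpu().numpy()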
import torch as t
import plotly.express as px
from tqdm import tqdm

def top_1_acc(OV_circuit):
    # Fraction of tokens whose largest score in the OV circuit is on the diagonal,
    # i.e. the circuit's top output for a token is the token itself.
    return ((OV_circuit.argmax(dim=0) == t.arange(0, OV_circuit.size(-1)).to(device)).sum() / OV_circuit.size(-1)).item()

def top_5_acc(OV_circuit):
    # Fraction of tokens whose own index appears among the top-5 scores in their column.
    return ((OV_circuit.topk(5, dim=0)[1] == t.arange(0, OV_circuit.size(-1)).to(device)).sum() / OV_circuit.size(-1)).item()
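# The goal: train a rank-64 matrix W_o @ W_v such that the OV circuit
# W_U @ W_o @ W_v @ W_E approximately "copies" tokens, i.e. for each token the
# largest entry in its column is on the diagonal (the token maps to itself).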
# Frozen unembedding / embedding matrices of the pretrained model.
W_U = model.unembed.W_U.to(device)
W_E = model.embed.W_E.to(device)

# Trainable low-rank factors: W_o @ W_v is a rank-64 [768, 768] matrix,
# the same shape and rank as a single attention head's W_O @ W_V.
W_o = t.randn(768, 64, requires_grad=True, device=device)
W_v = t.randn(64, 768, requires_grad=True, device=device)

optimizer = t.optim.AdamW([W_o, W_v], lr=0.001)

steps = 120000
batch_size = 1024

pbar = tqdm(range(steps))

i = t.eye(batch_size).to(device)
i_diag = i.diag()  # the diagonal of the identity is just a vector of ones

top_1 = 0
top_5 = 0
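# Shapes per training step, using this model's convention W_U: [d_vocab, d_model],
# W_E: [d_model, d_vocab] (an assumption read off from the indexing below):
#   W_U_subset:    [batch_size, 768]
#   W_E_subset:    [768, batch_size]
#   combined_WoWv: [batch_size, batch_size]   (OV circuit restricted to the batch)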
for step in pbar:
    optimizer.zero_grad()  # clear the gradients of W_o and W_v

    # Sample a random batch of token indices; take the matching rows of W_U
    # and columns of W_E.
    indices = t.randperm(W_U.size(0))[:batch_size].to(device)
    W_U_subset = W_U[indices]
    W_E_subset = W_E[:, indices]

    # OV circuit restricted to the sampled tokens.
    combined_WoWv = W_U_subset @ W_o @ W_v @ W_E_subset

    # Probability mass each token's column puts on the token itself.
    softmax = t.softmax(combined_WoWv, dim=0).diag()
    softmax[softmax > 0.9] = 1  # don't push probabilities already above 0.9 any higher

    loss = t.nn.functional.mse_loss(softmax, i_diag) + t.nn.functional.mse_loss(combined_WoWv, i) / 160
    # the second mse_loss just penalizes large activations; this keeps them from
    # exploding and similar in magnitude to the head's

    loss.backward()
    optimizer.step()

    if step % 500 == 0:
        pbar.set_postfix({'info': "Loss: {:.4f}, Top-1: {:.4f}, Top-5: {:.4f}".format(loss.item(), top_1, top_5)})

    if step % 5000 == 0:
        # Periodically evaluate top-1 / top-5 accuracy over the full vocabulary.
        try:
            t.cuda.empty_cache()
            del OV_circuit_full
        except NameError:
            pass
        OV_circuit_full = model.unembed.W_U @ W_o @ W_v @ model.embed.W_E
        top_1 = top_1_acc(OV_circuit_full)
        top_5 = top_5_acc(OV_circuit_full)
        del OV_circuit_full
# Final evaluation over the full vocabulary.
try:
    t.cuda.empty_cache()
    del OV_circuit_full
except NameError:
    pass
OV_circuit_full = model.unembed.W_U @ W_o @ W_v @ model.embed.W_E
print("Top 1 accuracy for the trained OV Circuit:", top_1_acc(OV_circuit_full))
print("Top 5 accuracy for the trained OV Circuit:", top_5_acc(OV_circuit_full))
try:
    del OV_circuit_full
except NameError:
    pass

print("The trained matrix rank (expected to be 64):", t.linalg.matrix_rank(W_o @ W_v).item())
# W_O @ W_V for every head in layer 1; heads 4 and 10 are the two copying heads
# this trained circuit is compared against.
all_WoWvs = model.blocks[1].attn.W_O @ model.blocks[1].attn.W_V
print("The combined L1H4 + L1H10 W_O@W_V rank (expected to be 128):", t.linalg.matrix_rank(all_WoWvs[4] + all_WoWvs[10]).item())

# Heatmaps of the two circuits, restricted to the last sampled batch of tokens.
px.imshow(
    to_numpy(W_U_subset @ (all_WoWvs[4] + all_WoWvs[10]) @ W_E_subset),
    labels={"x": "Output tokens", "y": "Input tokens"},
    width=800, height=800,
    title="two copying OV circuits from L1",
).show()
px.imshow(
    to_numpy(combined_WoWv),
    labels={"x": "Output tokens", "y": "Input tokens"},
    width=800, height=800,
    title="trained token copying circuit",
).show()