toyword2vec
# -*- coding: utf-8 -*-
"""toy-word2vec.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Lk3Td9MXzndT0ld1iaCF6FbggKCzBa2Z
"""
import jax
import jax.numpy as jnp
from jax import grad, jit
from typing import Tuple
def skipgram_model(W: jnp.ndarray, context: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    hidden = W[context]
    output = jnp.dot(hidden, W.T)
    prob = jax.nn.softmax(output)
    loss = -jnp.log(prob[target])
    return loss
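
# Added note: this is the skip-gram objective for a single (context, target)
# pair, L = -log softmax(W[context] @ W.T)[target]; in this toy version the
# same weight matrix W serves as both input and output embeddings.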
@jit
def train_step(W: jnp.ndarray, context: jnp.ndarray, target: jnp.ndarray,
               lr: float) -> jnp.ndarray:
    loss_grad = grad(skipgram_model)(W, context, target)
    return W - lr * loss_grad
def train(vocab_size: int, embed_dim: int,
          context_target_pairs: list[Tuple[int, int]],
          epochs: int = 5, lr: float = 0.1) -> jnp.ndarray:
    key = jax.random.PRNGKey(0)
    W = jax.random.normal(key, (vocab_size, embed_dim))
    for _ in range(epochs):
        for context, target in context_target_pairs:
            W = train_step(W, context, target, lr)
    return W
vocab = ["the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog"] | |
vocab_size = len(vocab) | |
embed_dim = 2 | |
word_to_index = {word: i for i, word in enumerate(vocab)} | |
context_target_pairs = [(word_to_index["the"], word_to_index["quick"]), | |
(word_to_index["quick"], word_to_index["brown"]), | |
(word_to_index["brown"], word_to_index["fox"]), | |
(word_to_index["fox"], word_to_index["jumps"]), | |
(word_to_index["jumps"], word_to_index["over"]), | |
(word_to_index["over"], word_to_index["lazy"]), | |
(word_to_index["lazy"], word_to_index["dog"])] | |
W = train(vocab_size, embed_dim, context_target_pairs) | |
for word, idx in word_to_index.items(): | |
print(f"{word}: {W[idx]}") | |
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
# Preprocess data and create training pairs
corpus = [
    "the quick brown fox jumps over the lazy dog",
    "i love machine learning",
    "jax is a great library",
    "word embeddings are useful"
]
# Build vocabulary
words = []
for sentence in corpus:
    words.extend(sentence.lower().split())
vocab = sorted(set(words))
word2idx = {word: idx for idx, word in enumerate(vocab)}
vocab_size = len(vocab)
# Generate training data (target, context) pairs
window_size = 2
training_data = []
for sentence in corpus:
    tokens = sentence.lower().split()
    token_indices = [word2idx[word] for word in tokens]
    for i, target in enumerate(token_indices):
        start = max(0, i - window_size)
        end = min(len(token_indices), i + window_size + 1)
        for j in range(start, end):
            if j != i and j < len(token_indices):
                context = token_indices[j]
                training_data.append((target, context))
# Convert to numpy arrays
targets = np.array([t for t, c in training_data])
contexts = np.array([c for t, c in training_data])
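
# Added sanity check (illustrative, not in the original): decode the first few
# (target, context) pairs back to words to confirm the sliding-window logic;
# for the first sentence this starts with ('the', 'quick'), ('the', 'brown'), ...
idx2word_check = {i: w for w, i in word2idx.items()}
print([(idx2word_check[t], idx2word_check[c]) for t, c in training_data[:5]])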
# Hyperparameters
embedding_dim = 32
learning_rate = 0.1
num_epochs = 100
batch_size = 16

# Initialize parameters
key = jax.random.PRNGKey(42)
key, target_key, context_key = jax.random.split(key, 3)
params = {
    'target_embeddings': jax.random.normal(target_key, (vocab_size, embedding_dim)) * 0.1,
    'context_embeddings': jax.random.normal(context_key, (vocab_size, embedding_dim)) * 0.1,
}
# Define loss function
def loss_fn(params, targets, contexts):
    target_embeds = params['target_embeddings'][targets]
    logits = jnp.dot(target_embeds, params['context_embeddings'].T)
    log_p = jax.nn.log_softmax(logits)
    loss = -jnp.mean(log_p[jnp.arange(len(contexts)), contexts])
    return loss
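
# Added note: loss_fn is full-softmax cross-entropy over the whole vocabulary,
# -mean(log p(context | target)). That is fine for a toy vocab, whereas real
# word2vec avoids the O(vocab_size) normalisation with negative sampling or a
# hierarchical softmax. If your optax version provides it, the same loss can
# also be written as
# optax.softmax_cross_entropy_with_integer_labels(logits, contexts).mean().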
# Initialize optimizer
optimizer = optax.sgd(learning_rate)
opt_state = optimizer.init(params)

# Define update step
@jax.jit
def update_step(params, opt_state, batch_targets, batch_contexts):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch_targets, batch_contexts)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss
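
# Added note: jax.jit compiles the whole step (forward pass, gradients, and the
# SGD update) into one XLA computation; params and opt_state are passed in and
# returned functionally rather than mutated in place.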
# Training loop
for epoch in range(num_epochs):
    # Shuffle data
    perm = np.random.permutation(len(targets))
    shuffled_targets = targets[perm]
    shuffled_contexts = contexts[perm]

    # Process batches
    epoch_loss = 0.0
    num_batches = len(shuffled_targets) // batch_size
    for i in range(num_batches):
        start = i * batch_size
        end = start + batch_size
        batch_t = shuffled_targets[start:end]
        batch_c = shuffled_contexts[start:end]
        params, opt_state, batch_loss = update_step(params, opt_state, batch_t, batch_c)
        epoch_loss += batch_loss

    # Print progress
    if epoch % 10 == 0:
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")
# Get final embeddings
word_embeddings = params['target_embeddings']

# Example usage: find similar words
def cosine_similarity(vec1, vec2):
    return jnp.dot(vec1, vec2) / (jnp.linalg.norm(vec1) * jnp.linalg.norm(vec2))

def most_similar(word, embeddings, top_n=5):
    idx = word2idx[word]
    word_vec = embeddings[idx]
    similarities = []
    for i, vec in enumerate(embeddings):
        if i != idx:
            similarities.append((i, float(cosine_similarity(word_vec, vec))))
    similarities.sort(key=lambda x: -x[1])
    return [(list(word2idx.keys())[i], sim) for i, sim in similarities[:top_n]]
print("\nMost similar to 'jax':") | |
for word, sim in most_similar('jax', word_embeddings): | |
print(f"{word}: {sim:.3f}") | |
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Preprocess data and create training pairs
corpus = [
    "the quick brown fox jumps over the lazy dog",
    "i love machine learning",
    "jax is a great library",
    "word embeddings are useful"
]

# Build vocabulary
words = []
for sentence in corpus:
    words.extend(sentence.lower().split())
vocab = sorted(set(words))
word2idx = {word: idx for idx, word in enumerate(vocab)}
vocab_size = len(vocab)

# Generate training data (target, context) pairs
window_size = 2
training_data = []
for sentence in corpus:
    tokens = sentence.lower().split()
    token_indices = [word2idx[word] for word in tokens]
    for i, target in enumerate(token_indices):
        start = max(0, i - window_size)
        end = min(len(token_indices), i + window_size + 1)
        for j in range(start, end):
            if j != i and j < len(token_indices):
                context = token_indices[j]
                training_data.append((target, context))

# Convert to numpy arrays
targets = np.array([t for t, c in training_data])
contexts = np.array([c for t, c in training_data])

# Hyperparameters
embedding_dim = 32
learning_rate = 0.1
num_epochs = 100
batch_size = 16
class Word2Vec(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(Word2Vec, self).__init__()
        self.target_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)
        nn.init.normal_(self.target_embeddings.weight, std=0.1)
        nn.init.normal_(self.context_embeddings.weight, std=0.1)

    def forward(self, targets, contexts):
        target_embeds = self.target_embeddings(targets)
        logits = torch.matmul(target_embeds, self.context_embeddings.weight.T)
        log_probs = torch.nn.functional.log_softmax(logits, dim=1)
        return -torch.mean(log_probs[torch.arange(len(contexts)), contexts])
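
# Added note: forward() folds the cross-entropy loss into the model and returns
# a scalar directly; an equivalent formulation would return the logits and
# apply torch.nn.functional.cross_entropy(logits, contexts) outside the model.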
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Word2Vec(vocab_size, embedding_dim).to(device)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# Training loop
for epoch in range(num_epochs):
    perm = np.random.permutation(len(targets))
    shuffled_targets = targets[perm]
    shuffled_contexts = contexts[perm]

    epoch_loss = 0.0
    num_batches = len(shuffled_targets) // batch_size
    for i in range(num_batches):
        start = i * batch_size
        end = start + batch_size
        batch_t = torch.tensor(shuffled_targets[start:end]).to(device)
        batch_c = torch.tensor(shuffled_contexts[start:end]).to(device)

        optimizer.zero_grad()
        loss = model(batch_t, batch_c)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    if epoch % 10 == 0:
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")
# Get final embeddings
word_embeddings = model.target_embeddings.weight.detach().cpu().numpy()

# Example usage: find similar words
def cosine_similarity(vec1, vec2):
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

def most_similar(word, embeddings, top_n=5):
    idx = word2idx[word]
    word_vec = embeddings[idx]
    similarities = []
    for i, vec in enumerate(embeddings):
        if i != idx:
            sim = cosine_similarity(word_vec, vec)
            similarities.append((i, sim))
    similarities.sort(key=lambda x: -x[1])
    idx2word = {idx: word for word, idx in word2idx.items()}
    return [(idx2word[i], sim) for i, sim in similarities[:top_n]]

print("\nMost similar to 'jax':")
for word, sim in most_similar('jax', word_embeddings):
    print(f"{word}: {sim:.3f}")