@tlkahn
Created February 2, 2025 03:16
toyword2vec
# -*- coding: utf-8 -*-
"""toy-word2vec.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1Lk3Td9MXzndT0ld1iaCF6FbggKCzBa2Z
"""
import jax
import jax.numpy as jnp
from jax import grad, jit
from typing import Tuple
def skipgram_model(W: jnp.ndarray, context: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    # Full-softmax skip-gram loss with a single shared embedding matrix W:
    # -log p(target | context).
    hidden = W[context]
    output = jnp.dot(hidden, W.T)
    prob = jax.nn.softmax(output)
    loss = -jnp.log(prob[target])
    return loss
@jit
def train_step(W: jnp.ndarray, context: jnp.ndarray, target: jnp.ndarray,
               lr: float) -> jnp.ndarray:
    loss_grad = grad(skipgram_model)(W, context, target)
    return W - lr * loss_grad
def train(vocab_size: int, embed_dim: int,
          context_target_pairs: list[Tuple[int, int]],
          epochs: int = 5, lr: float = 0.1) -> jnp.ndarray:
    key = jax.random.PRNGKey(0)
    W = jax.random.normal(key, (vocab_size, embed_dim))
    for _ in range(epochs):
        for context, target in context_target_pairs:
            W = train_step(W, context, target, lr)
    return W
vocab = ["the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog"]
vocab_size = len(vocab)
embed_dim = 2
word_to_index = {word: i for i, word in enumerate(vocab)}
context_target_pairs = [(word_to_index["the"], word_to_index["quick"]),
                        (word_to_index["quick"], word_to_index["brown"]),
                        (word_to_index["brown"], word_to_index["fox"]),
                        (word_to_index["fox"], word_to_index["jumps"]),
                        (word_to_index["jumps"], word_to_index["over"]),
                        (word_to_index["over"], word_to_index["lazy"]),
                        (word_to_index["lazy"], word_to_index["dog"])]
W = train(vocab_size, embed_dim, context_target_pairs)
for word, idx in word_to_index.items():
    print(f"{word}: {W[idx]}")
import jax
import jax.numpy as jnp
import optax
import numpy as np
from functools import partial
# Preprocess data and create training pairs
corpus = [
"the quick brown fox jumps over the lazy dog",
"i love machine learning",
"jax is a great library",
"word embeddings are useful"
]
# Build vocabulary
words = []
for sentence in corpus:
    words.extend(sentence.lower().split())
vocab = sorted(set(words))
word2idx = {word: idx for idx, word in enumerate(vocab)}
vocab_size = len(vocab)
# Generate training data (target, context) pairs
window_size = 2
training_data = []
for sentence in corpus:
    tokens = sentence.lower().split()
    token_indices = [word2idx[word] for word in tokens]
    for i, target in enumerate(token_indices):
        start = max(0, i - window_size)
        end = min(len(token_indices), i + window_size + 1)
        for j in range(start, end):
            if j != i and j < len(token_indices):
                context = token_indices[j]
                training_data.append((target, context))
# Convert to numpy arrays
targets = np.array([t for t, c in training_data])
contexts = np.array([c for t, c in training_data])
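# Quick sanity check (a small addition; `idx2word` is an assumed helper name):
# number of (target, context) pairs and the first pair mapped back to words.
idx2word = {idx: word for word, idx in word2idx.items()}
print(len(training_data), (idx2word[int(targets[0])], idx2word[int(contexts[0])]))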
# Hyperparameters
embedding_dim = 32
learning_rate = 0.1
num_epochs = 100
batch_size = 16
# Initialize parameters
key = jax.random.PRNGKey(42)
key, target_key, context_key = jax.random.split(key, 3)
params = {
    'target_embeddings': jax.random.normal(target_key, (vocab_size, embedding_dim)) * 0.1,
    'context_embeddings': jax.random.normal(context_key, (vocab_size, embedding_dim)) * 0.1,
}
# Define loss function
def loss_fn(params, targets, contexts):
    # Mean cross-entropy of the true context words under a full softmax
    # over the context-embedding table.
    target_embeds = params['target_embeddings'][targets]
    logits = jnp.dot(target_embeds, params['context_embeddings'].T)
    log_p = jax.nn.log_softmax(logits)
    loss = -jnp.mean(log_p[jnp.arange(len(contexts)), contexts])
    return loss
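# Optional check (an added sketch): with small random embeddings the softmax is
# close to uniform, so the initial loss should sit near ln(vocab_size), roughly
# 3.0 for this 21-word vocabulary.
print(float(loss_fn(params, targets, contexts)))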
# Initialize optimizer
optimizer = optax.sgd(learning_rate)
opt_state = optimizer.init(params)
# Define update step
@jax.jit
def update_step(params, opt_state, batch_targets, batch_contexts):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch_targets, batch_contexts)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss
# Training loop
for epoch in range(num_epochs):
    # Shuffle data
    perm = np.random.permutation(len(targets))
    shuffled_targets = targets[perm]
    shuffled_contexts = contexts[perm]
    # Process batches
    epoch_loss = 0.0
    num_batches = len(shuffled_targets) // batch_size
    for i in range(num_batches):
        start = i * batch_size
        end = start + batch_size
        batch_t = shuffled_targets[start:end]
        batch_c = shuffled_contexts[start:end]
        params, opt_state, batch_loss = update_step(params, opt_state, batch_t, batch_c)
        epoch_loss += batch_loss
    # Print progress
    if epoch % 10 == 0:
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")
# Get final embeddings
word_embeddings = params['target_embeddings']
# Example usage: find similar words
def cosine_similarity(vec1, vec2):
    return jnp.dot(vec1, vec2) / (jnp.linalg.norm(vec1) * jnp.linalg.norm(vec2))

def most_similar(word, embeddings, top_n=5):
    idx = word2idx[word]
    word_vec = embeddings[idx]
    similarities = []
    for i, vec in enumerate(embeddings):
        if i != idx:
            similarities.append((i, float(cosine_similarity(word_vec, vec))))
    similarities.sort(key=lambda x: -x[1])
    return [(list(word2idx.keys())[i], sim) for i, sim in similarities[:top_n]]
print("\nMost similar to 'jax':")
for word, sim in most_similar('jax', word_embeddings):
    print(f"{word}: {sim:.3f}")
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Preprocess data and create training pairs
corpus = [
"the quick brown fox jumps over the lazy dog",
"i love machine learning",
"jax is a great library",
"word embeddings are useful"
]
# Build vocabulary
words = []
for sentence in corpus:
    words.extend(sentence.lower().split())
vocab = sorted(set(words))
word2idx = {word: idx for idx, word in enumerate(vocab)}
vocab_size = len(vocab)
# Generate training data (target, context) pairs
window_size = 2
training_data = []
for sentence in corpus:
    tokens = sentence.lower().split()
    token_indices = [word2idx[word] for word in tokens]
    for i, target in enumerate(token_indices):
        start = max(0, i - window_size)
        end = min(len(token_indices), i + window_size + 1)
        for j in range(start, end):
            if j != i and j < len(token_indices):
                context = token_indices[j]
                training_data.append((target, context))
# Convert to numpy arrays
targets = np.array([t for t, c in training_data])
contexts = np.array([c for t, c in training_data])
# Hyperparameters
embedding_dim = 32
learning_rate = 0.1
num_epochs = 100
batch_size = 16
class Word2Vec(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(Word2Vec, self).__init__()
        self.target_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)
        nn.init.normal_(self.target_embeddings.weight, std=0.1)
        nn.init.normal_(self.context_embeddings.weight, std=0.1)

    def forward(self, targets, contexts):
        target_embeds = self.target_embeddings(targets)
        logits = torch.matmul(target_embeds, self.context_embeddings.weight.T)
        log_probs = torch.nn.functional.log_softmax(logits, dim=1)
        # Keep the row indices on the same device as the logits so the
        # indexing also works when the model is on GPU.
        rows = torch.arange(len(contexts), device=logits.device)
        return -torch.mean(log_probs[rows, contexts])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Word2Vec(vocab_size, embedding_dim).to(device)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# Training loop
for epoch in range(num_epochs):
    perm = np.random.permutation(len(targets))
    shuffled_targets = targets[perm]
    shuffled_contexts = contexts[perm]
    epoch_loss = 0.0
    num_batches = len(shuffled_targets) // batch_size
    for i in range(num_batches):
        start = i * batch_size
        end = start + batch_size
        batch_t = torch.tensor(shuffled_targets[start:end]).to(device)
        batch_c = torch.tensor(shuffled_contexts[start:end]).to(device)
        optimizer.zero_grad()
        loss = model(batch_t, batch_c)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    if epoch % 10 == 0:
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")
# Get final embeddings
word_embeddings = model.target_embeddings.weight.detach().cpu().numpy()
# Example usage: find similar words
def cosine_similarity(vec1, vec2):
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

def most_similar(word, embeddings, top_n=5):
    idx = word2idx[word]
    word_vec = embeddings[idx]
    similarities = []
    for i, vec in enumerate(embeddings):
        if i != idx:
            sim = cosine_similarity(word_vec, vec)
            similarities.append((i, sim))
    similarities.sort(key=lambda x: -x[1])
    idx2word = {idx: word for word, idx in word2idx.items()}
    return [(idx2word[i], sim) for i, sim in similarities[:top_n]]
print("\nMost similar to 'jax':")
for word, sim in most_similar('jax', word_embeddings):
    print(f"{word}: {sim:.3f}")