
@tlkahn
Created January 29, 2025 21:45
RLHF prototype
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Mock dataset
def create_mock_dataset(vocab_size, seq_len, num_pairs):
    preferred = torch.randint(0, vocab_size, (num_pairs, seq_len))
    non_preferred = torch.randint(0, vocab_size, (num_pairs, seq_len))
    return [(preferred[i], non_preferred[i]) for i in range(num_pairs)]

class LanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Encoder-only stack with a causal mask; nn.Transformer itself expects
        # both src and tgt, so it cannot be called on a single sequence.
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(x.size(1))
        x = self.transformer(x, mask=causal_mask)
        return self.fc(x)

    def generate(self, prompt, max_length):
        with torch.no_grad():
            for _ in range(max_length):
                output = self(prompt)
                next_token = output[:, -1:].argmax(dim=-1)
                prompt = torch.cat([prompt, next_token], dim=1)
        return prompt

class RewardModel(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.lstm = nn.LSTM(d_model, d_model, batch_first=True)
        self.fc = nn.Linear(d_model, 1)

    def forward(self, x):
        x = self.embedding(x)
        _, (h, _) = self.lstm(x)
        return self.fc(h.squeeze(0))


def compute_reward_loss(reward_model, preferred, non_preferred):
    preferred_reward = reward_model(preferred.unsqueeze(0))
    non_preferred_reward = reward_model(non_preferred.unsqueeze(0))
    loss = -F.logsigmoid(preferred_reward - non_preferred_reward)
    return loss.mean()

def compute_kl_divergence(current_lm, original_lm, input_data):
    with torch.no_grad():
        original_logits = original_lm(input_data)
    current_logits = current_lm(input_data)

    original_probs = F.softmax(original_logits, dim=-1)
    kl_div = F.kl_div(
        F.log_softmax(current_logits, dim=-1),
        original_probs,
        reduction='batchmean'
    )
    return kl_div


# Hyperparameters
vocab_size = 10000
d_model = 512
nhead = 8
num_layers = 6
num_epochs = 10
rl_steps = 100
kl_weight = 0.1
seq_len = 20
num_pairs = 100

# Initialize models
lm = LanguageModel(vocab_size, d_model, nhead, num_layers)
rm = RewardModel(vocab_size, d_model)
original_lm = LanguageModel(vocab_size, d_model, nhead, num_layers)
original_lm.load_state_dict(lm.state_dict())  # Reference copy of the initial policy, used for the KL penalty

# Optimizers
lm_optimizer = optim.Adam(lm.parameters())
rm_optimizer = optim.Adam(rm.parameters())


# Training loop
human_preference_dataset = create_mock_dataset(vocab_size, seq_len, num_pairs)
for epoch in range(num_epochs):
    # Train reward model on human preferences
    for preferred, non_preferred in human_preference_dataset:
        rm_optimizer.zero_grad()
        rm_loss = compute_reward_loss(rm, preferred, non_preferred)
        rm_loss.backward()
        rm_optimizer.step()

    # RL training of language model
    for _ in range(rl_steps):
        lm_optimizer.zero_grad()

        prompt = torch.randint(0, vocab_size, (1, 10))  # Random prompt
        response = lm.generate(prompt, max_length=20)

        # Score the response; the reward model is fixed during this phase
        with torch.no_grad():
            reward = rm(response)

        # REINFORCE: generation ran under no_grad, so re-score the response to
        # get differentiable log-probabilities of the generated tokens
        logits = lm(response[:, :-1])
        log_probs = F.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(-1, response[:, 1:].unsqueeze(-1)).squeeze(-1)
        pg_loss = -(reward * token_log_probs.sum(dim=-1)).mean()

        kl_div = compute_kl_divergence(lm, original_lm, response)

        loss = pg_loss + kl_weight * kl_div

        loss.backward()
        lm_optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs} completed")

The RL part is primarily in the second loop of the training process:

# RL training of language model
for _ in range(rl_steps):
    lm_optimizer.zero_grad()

    prompt = torch.randint(0, vocab_size, (1, 10))  # Random prompt
    response = lm.generate(prompt, max_length=20)

    # Score the response; the reward model is fixed during this phase
    with torch.no_grad():
        reward = rm(response)

    # REINFORCE: generation ran under no_grad, so re-score the response to
    # get differentiable log-probabilities of the generated tokens
    logits = lm(response[:, :-1])
    log_probs = F.log_softmax(logits, dim=-1)
    token_log_probs = log_probs.gather(-1, response[:, 1:].unsqueeze(-1)).squeeze(-1)
    pg_loss = -(reward * token_log_probs.sum(dim=-1)).mean()

    kl_div = compute_kl_divergence(lm, original_lm, response)

    loss = pg_loss + kl_weight * kl_div

    loss.backward()
    lm_optimizer.step()

This implements a simplified version of policy-gradient RL (the objective it approximates is sketched below):

  1. The language model generates a response (the action)
  2. The reward model scores the response (the reward)
  3. The loss is the negative reward-weighted log-probability of the response (REINFORCE)
  4. A KL-divergence penalty prevents excessive drift from the original model
  5. The language model is updated to maximize the expected reward
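
For reference, here is a sketch of the two standard RLHF objectives this prototype approximates (β corresponds to kl_weight, r_φ to the reward model, and π_ref to original_lm; this is the usual formulation, not something stated in the gist itself).

Reward model training (pairwise preference loss, matching compute_reward_loss):

  L(\phi) = -\mathbb{E}_{(y^+, y^-)}\big[\log \sigma\big(r_\phi(y^+) - r_\phi(y^-)\big)\big]

Policy update (matching the RL loop):

  \max_\theta \; \mathbb{E}_{y \sim \pi_\theta}\big[r_\phi(y)\big] - \beta \, D_{\mathrm{KL}}\big(\pi_\theta \,\|\, \pi_{\mathrm{ref}}\big)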