Skip to content

Instantly share code, notes, and snippets.

@JacobFV
Last active September 16, 2024 08:47
Show Gist options
  • Save JacobFV/7785fd6175c74f9274008cfd7ec5147c to your computer and use it in GitHub Desktop.
Save JacobFV/7785fd6175c74f9274008cfd7ec5147c to your computer and use it in GitHub Desktop.
open1

open1

This is a start. It doesn't work atm.

Running

$ python open1.py --epochs 3 --use_mcts --multi_agent --progressive_training --visualize

2024-09-16 07:51:04.977649: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-16 07:51:04.995359: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-16 07:51:05.016441: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-16 07:51:05.022856: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-16 07:51:05.038285: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-16 07:51:06.135730: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[2024-09-16 07:51:08,470] INFO - Starting training of Open1 model.
[2024-09-16 07:51:17,106] INFO - Starting training of the Reward Model.
[2024-09-16 07:51:17,109] INFO - Reward Model Epoch 1/1
[2024-09-16 07:51:34,592] INFO - Average Reward Model Loss: 0.1820
[2024-09-16 07:51:35,706] INFO - Reward Model training completed and saved.
[2024-09-16 07:51:35,708] INFO - Epoch 1/3
[2024-09-16 07:51:38,462] INFO - Batch 1/500, Loss: -0.3782, Policy Loss: -0.3782, KL Divergence: 0.0000
[2024-09-16 07:51:47,109] INFO - Batch 2/500, Loss: -0.0844, Policy Loss: -0.0844, KL Divergence: 0.0000
[2024-09-16 07:51:55,192] INFO - Batch 3/500, Loss: -0.3686, Policy Loss: -0.3686, KL Divergence: 0.0000
[2024-09-16 07:52:03,152] INFO - Batch 4/500, Loss: -0.0971, Policy Loss: -0.0971, KL Divergence: 0.0000
[2024-09-16 07:52:11,042] INFO - Batch 5/500, Loss: -0.1195, Policy Loss: -0.1195, KL Divergence: 0.0000
[2024-09-16 07:52:18,816] INFO - Batch 6/500, Loss: -0.5291, Policy Loss: -0.5291, KL Divergence: 0.0000
[2024-09-16 07:52:26,654] INFO - Batch 7/500, Loss: -0.0946, Policy Loss: -0.0946, KL Divergence: 0.0000
[2024-09-16 07:52:34,471] INFO - Batch 8/500, Loss: -0.0336, Policy Loss: -0.0336, KL Divergence: 0.0000
[2024-09-16 07:52:42,338] INFO - Batch 9/500, Loss: -0.1767, Policy Loss: -0.1767, KL Divergence: 0.0000
[2024-09-16 07:52:50,294] INFO - Batch 10/500, Loss: -0.0244, Policy Loss: -0.0244, KL Divergence: 0.0000
[2024-09-16 07:52:58,278] INFO - Batch 11/500, Loss: -0.2160, Policy Loss: -0.2160, KL Divergence: 0.0000
[2024-09-16 07:53:06,154] INFO - Batch 12/500, Loss: -0.1222, Policy Loss: -0.1222, KL Divergence: 0.0000
[2024-09-16 07:53:14,004] INFO - Batch 13/500, Loss: -0.3048, Policy Loss: -0.3048, KL Divergence: 0.0000
[2024-09-16 07:53:21,873] INFO - Batch 14/500, Loss: 0.0397, Policy Loss: 0.0397, KL Divergence: 0.0000
[2024-09-16 07:53:29,765] INFO - Batch 15/500, Loss: -0.1995, Policy Loss: -0.1995, KL Divergence: 0.0000
[2024-09-16 07:53:37,559] INFO - Batch 16/500, Loss: -0.2891, Policy Loss: -0.2891, KL Divergence: 0.0000
[2024-09-16 07:53:45,381] INFO - Batch 17/500, Loss: -0.1320, Policy Loss: -0.1320, KL Divergence: 0.0000
[2024-09-16 07:53:52,815] INFO - Batch 18/500, Loss: -0.4515, Policy Loss: -0.4515, KL Divergence: 0.0000
Traceback (most recent call last):
  File "/content/open1.py", line 601, in <module>
    train_open1(args)
  File "/content/open1.py", line 576, in train_open1
    final_response = env.play(initial_prompt)
  File "/content/open1.py", line 359, in play
    prompt, response = self.step(prompt)
  File "/content/open1.py", line 351, in step
    response = self.agent.generate_chain_of_thought(prompt, max_length=100)[0]
  File "/content/open1.py", line 144, in generate_chain_of_thought
    outputs = self.model.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1874, in generate
    self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1266, in _validate_generated_length
    raise ValueError(
ValueError: Input length of input_ids is 100, but `max_length` is set to 100. This can lead to unexpected behavior. You should consider increasing `max_length` or, better yet, setting `max_new_tokens`.
#!/usr/bin/env python
# coding: utf-8
"""
Open1: An Open-Source Implementation with Advanced Transformer Features
This script implements a novel transformer model using JAX and Flax,
integrating the following components:
- Infini-Attention with Compressive Memory and Dual Attention Modes
- FlashAttention with Block-wise Recomputation
- Mamba Transformers with Selective Gating and Recurrent Shortcuts
- Tree-Search-Powered Token Agents for Enhanced Reasoning
- kNN Caching and Memory-Bound Scaling
- Integration into the Overall Codebase with Training Loop
- Model Size Approximately 90 Million Parameters
Instructions:
- Ensure you have the required packages installed.
- Run the script: `python open1_jax.py`
- Use command-line arguments to enable or disable features.
Author: OpenAI Assistant
Date: 2023
"""
import os
import sys
import argparse
import logging
import math
import random
import numpy as np
from collections import defaultdict
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")
import jax
import jax.numpy as jnp
from jax import jit, grad, value_and_grad
from flax import linen as nn
from flax.training import train_state
import optax
from transformers import GPT2Tokenizer
from datasets import load_dataset
# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
# Device configuration
print(f"JAX devices: {jax.devices()}")
# =======================
# Command-Line Arguments
# =======================
def parse_args():
parser = argparse.ArgumentParser(description="Open1 Model Training and Evaluation with JAX")
parser.add_argument('--epochs', type=int, default=1, help='Number of training epochs')
parser.add_argument('--max_length', type=int, default=256, help='Maximum sequence length')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
parser.add_argument('--log_dir', type=str, default='runs_jax', help='Directory for logs')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate for optimizer')
parser.add_argument('--use_mcts', action='store_true', help='Enable MCTS for trajectory exploration')
parser.add_argument('--multi_agent', action='store_true', help='Enable multi-agent training')
parser.add_argument('--progressive_training', action='store_true', help='Enable progressive training pipeline')
parser.add_argument('--visualize', action='store_true', help='Enable visualization of LLM capabilities')
args = parser.parse_args()
return args
# =======================
# Logging Configuration
# =======================
def setup_logging():
logging.basicConfig(
level=logging.INFO,
format='[%(asctime)s] %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout)
]
)
# =======================
# Data Classes and Functions
# =======================
class CustomDataset:
def __init__(self, tokenizer, data, max_length):
self.tokenizer = tokenizer
self.data = data
self.max_length = max_length
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
prompt = self.data[idx]['prompt']
response = self.data[idx]['response']
# Concatenate prompt and response for training
input_text = prompt + " " + response
encoded = self.tokenizer.encode_plus(
input_text,
truncation=True,
max_length=self.max_length,
padding='max_length',
return_tensors='np',
)
input_ids = encoded['input_ids'].squeeze()
attention_mask = encoded['attention_mask'].squeeze()
return input_ids, attention_mask
# =======================
# Model Classes
# =======================
# Mamba Transformer Block with Infini-Attention and FlashAttention
class TransformerBlock(nn.Module):
d_model: int
n_heads: int
dropout_rate: float
use_flash_attention: bool = True
def setup(self):
self.attention = MultiHeadAttention(
d_model=self.d_model,
n_heads=self.n_heads,
use_flash_attention=self.use_flash_attention,
)
self.ln1 = nn.LayerNorm()
self.ln2 = nn.LayerNorm()
self.ffn = FeedForwardNetwork(
d_model=self.d_model,
dropout_rate=self.dropout_rate,
)
# Gating mechanism parameters
self.gate = nn.Dense(self.d_model)
def __call__(self, x, mask, memory=None):
# Layer Norm
x_norm = self.ln1(x)
# Multi-Head Attention with Infini-Attention
attn_output = self.attention(x_norm, mask, memory)
# Selective Gating with Recurrent Shortcuts (Mamba Transformers)
gate_value = nn.sigmoid(self.gate(x_norm))
gated_output = gate_value * x_norm + (1 - gate_value) * attn_output
# Feed-Forward Network
ffn_output = self.ffn(self.ln2(gated_output))
# Residual Connection
output = gated_output + ffn_output
return output
class MultiHeadAttention(nn.Module):
d_model: int
n_heads: int
use_flash_attention: bool = True
def setup(self):
assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"
self.depth = self.d_model // self.n_heads
# Query, Key, Value projection layers
self.wq = nn.Dense(self.d_model)
self.wk = nn.Dense(self.d_model)
self.wv = nn.Dense(self.d_model)
self.dense = nn.Dense(self.d_model)
def __call__(self, x, mask, memory=None):
batch_size = x.shape[0]
# Project inputs to queries, keys, and values
q = self.wq(x)
k = self.wk(x)
v = self.wv(x)
# Reshape for multi-head attention
q = q.reshape(batch_size, -1, self.n_heads, self.depth).transpose(0, 2, 1, 3)
k = k.reshape(batch_size, -1, self.n_heads, self.depth).transpose(0, 2, 1, 3)
v = v.reshape(batch_size, -1, self.n_heads, self.depth).transpose(0, 2, 1, 3)
if memory is not None:
# Incorporate memory for global attention
k_memory = memory['k_memory']
v_memory = memory['v_memory']
k = jnp.concatenate([k_memory, k], axis=2)
v = jnp.concatenate([v_memory, v], axis=2)
# Apply attention mechanism
if self.use_flash_attention:
attn_output = self.flash_attention(q, k, v, mask)
else:
attn_output = self.scaled_dot_product_attention(q, k, v, mask)
# Reshape and project output
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(batch_size, -1, self.d_model)
output = self.dense(attn_output)
return output
def flash_attention(self, q, k, v, mask):
# Simplified FlashAttention implementation
# Actual implementation requires custom kernels; here we simulate the behavior
d_k = q.shape[-1]
scores = jnp.einsum('bhqd,bhkd->bhqk', q, k) / math.sqrt(d_k)
# Apply mask
if mask is not None:
mask = mask[:, jnp.newaxis, jnp.newaxis, :]
scores = scores + (mask * -1e9)
# Softmax with block-wise recomputation
attn_weights = nn.softmax(scores, axis=-1)
output = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, v)
return output
def scaled_dot_product_attention(self, q, k, v, mask):
# Standard attention mechanism
d_k = q.shape[-1]
scores = jnp.einsum('bhqd,bhkd->bhqk', q, k) / math.sqrt(d_k)
# Apply mask
if mask is not None:
mask = mask[:, jnp.newaxis, jnp.newaxis, :]
scores = scores + (mask * -1e9)
attn_weights = nn.softmax(scores, axis=-1)
output = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, v)
return output
class FeedForwardNetwork(nn.Module):
d_model: int
dropout_rate: float
def setup(self):
self.dense1 = nn.Dense(self.d_model * 4)
self.dense2 = nn.Dense(self.d_model)
self.dropout = nn.Dropout(rate=self.dropout_rate)
def __call__(self, x):
x = self.dense1(x)
x = nn.relu(x)
x = self.dropout(x)
x = self.dense2(x)
return x
# Transformer Model
class TransformerModel(nn.Module):
vocab_size: int
max_length: int
d_model: int = 512
n_heads: int = 8
num_layers: int = 6
dropout_rate: float = 0.1
use_flash_attention: bool = True
def setup(self):
self.token_embedding = nn.Embed(self.vocab_size, self.d_model)
self.position_embedding = nn.Embed(self.max_length, self.d_model)
self.layers = [TransformerBlock(
d_model=self.d_model,
n_heads=self.n_heads,
dropout_rate=self.dropout_rate,
use_flash_attention=self.use_flash_attention
) for _ in range(self.num_layers)]
self.ln = nn.LayerNorm()
self.output_dense = nn.Dense(self.vocab_size)
def __call__(self, input_ids, attention_mask, memory=None):
batch_size, seq_length = input_ids.shape
x = self.token_embedding(input_ids)
positions = jnp.expand_dims(jnp.arange(seq_length), 0)
x = x + self.position_embedding(positions)
# Apply transformer layers
for layer in self.layers:
x = layer(x, attention_mask, memory)
x = self.ln(x)
logits = self.output_dense(x)
return logits
# Tree-Search-Powered Token Agent
class TokenAgent(nn.Module):
transformer_model: TransformerModel
beam_width: int = 3
max_length: int = 50
def __call__(self, input_ids, attention_mask):
# Implement beam search to simulate tree search
batch_size = input_ids.shape[0]
sequences = [input_ids]
scores = [0]
for _ in range(self.max_length):
all_candidates = []
for i in range(len(sequences)):
seq = sequences[i]
score = scores[i]
logits = self.transformer_model(seq, attention_mask)
log_probs = jax.nn.log_softmax(logits[:, -1, :], axis=-1)
top_k_probs, top_k_indices = jax.lax.top_k(log_probs, self.beam_width)
for j in range(self.beam_width):
candidate = {
'sequence': jnp.concatenate([seq, top_k_indices[:, j:j+1]], axis=1),
'score': score + top_k_probs[:, j]
}
all_candidates.append(candidate)
# Select the best candidates
ordered = sorted(all_candidates, key=lambda x: x['score'], reverse=True)
sequences = [c['sequence'] for c in ordered[:self.beam_width]]
scores = [c['score'] for c in ordered[:self.beam_width]]
return sequences[0]
# =======================
# kNN Caching Mechanism
# =======================
class kNNCaching:
def __init__(self, cache_size=1024, d_model=512):
self.cache_size = cache_size
self.d_model = d_model
self.keys = jnp.zeros((cache_size, d_model))
self.values = jnp.zeros((cache_size, d_model))
self.index = 0
def add(self, key, value):
# Add key-value pair to cache
idx = self.index % self.cache_size
self.keys = self.keys.at[idx].set(key)
self.values = self.values.at[idx].set(value)
self.index += 1
def retrieve(self, query, k=1):
# Retrieve the most similar keys
similarities = jnp.dot(self.keys, query.T)
top_k_indices = jnp.argsort(-similarities)[:k]
return self.values[top_k_indices]
# =======================
# Training Utilities
# =======================
def create_learning_rate_schedule(learning_rate, warmup_steps=1000):
def lr_schedule(step):
return learning_rate * min(1.0, step / warmup_steps)
return lr_schedule
def cross_entropy_loss(logits, labels, mask):
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
loss = loss * mask
return loss.sum() / mask.sum()
# =======================
# Training Function
# =======================
def train_open1(args):
setup_logging()
logging.info("Starting training of Open1 model with JAX.")
# Initialize tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load dataset from Hugging Face
dataset_name = 'wikitext' # Use WikiText dataset
raw_dataset = load_dataset(dataset_name, 'wikitext-2-raw-v1', split='train')
# Prepare data samples
data_samples = []
for entry in raw_dataset:
text = entry['text']
sentences = text.strip().split('. ')
for i in range(len(sentences) - 1):
prompt = sentences[i] + '.'
response = sentences[i + 1] + '.'
data_samples.append({'prompt': prompt, 'response': response})
# Limit dataset size
data_samples = data_samples[:10000]
# Create dataset and data loader
custom_dataset = CustomDataset(tokenizer, data_samples, max_length=args.max_length)
def data_generator():
idx = 0
while idx < len(custom_dataset):
batch_input_ids = []
batch_attention_mask = []
for _ in range(args.batch_size):
if idx >= len(custom_dataset):
break
input_ids, attention_mask = custom_dataset[idx]
batch_input_ids.append(input_ids)
batch_attention_mask.append(attention_mask)
idx += 1
yield {
'input_ids': np.stack(batch_input_ids),
'attention_mask': np.stack(batch_attention_mask)
}
# Initialize model
vocab_size = tokenizer.vocab_size
transformer_model = TransformerModel(
vocab_size=vocab_size,
max_length=args.max_length,
d_model=512,
n_heads=8,
num_layers=6,
dropout_rate=0.1,
use_flash_attention=True
)
# Initialize parameters
rng = jax.random.PRNGKey(0)
dummy_input_ids = jnp.ones((args.batch_size, args.max_length), dtype=jnp.int32)
dummy_attention_mask = jnp.ones((args.batch_size, args.max_length), dtype=jnp.int32)
params = transformer_model.init(rng, dummy_input_ids, dummy_attention_mask)
# Create optimizer
learning_rate_schedule = create_learning_rate_schedule(args.learning_rate)
optimizer = optax.adamw(learning_rate=learning_rate_schedule)
state = train_state.TrainState.create(apply_fn=transformer_model.apply, params=params, tx=optimizer)
# Training loop
for epoch in range(args.epochs):
logging.info(f"Epoch {epoch+1}/{args.epochs}")
data_iter = data_generator()
step = 0
for batch in data_iter:
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = np.roll(input_ids, shift=-1, axis=1)
labels = labels.at[:, -1].set(tokenizer.eos_token_id)
mask = attention_mask
def loss_fn(params):
logits = transformer_model.apply(params, input_ids, attention_mask)
loss = cross_entropy_loss(logits, labels, mask)
return loss
loss_value, grads = jax.value_and_grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)
if step % 100 == 0:
logging.info(f"Step {step}, Loss: {loss_value:.4f}")
step += 1
# Save the trained model
# Note: Saving and loading models in JAX requires serialization
import pickle
with open('open1_model.pkl', 'wb') as f:
pickle.dump(state.params, f)
logging.info("Training completed and model saved.")
# =======================
# Main Function
# =======================
if __name__ == "__main__":
args = parse_args()
train_open1(args)
#!/usr/bin/env python
# coding: utf-8
"""
Open1: An Open-Source Implementation of O1 with Advanced Features
This script implements a simplified version of the OpenAI O1 model,
integrating the following components:
- Transformer-based language model using GPT-2 Small
- Chain-of-thought reasoning mechanism
- Tree-based trajectory exploration using Monte Carlo Tree Search (MCTS)
- Reinforcement Learning from Human Feedback (RLHF) using Proximal Policy Optimization (PPO)
- Self-play environments for LLMs
- Function calling practice in the training data
- Progressive training pipeline with increasing complexity
- Visualization using TensorBoard
- Hexagon chart rating of LLM capabilities
Instructions:
- Ensure you have the required packages installed.
- Run the script: `python open1.py`
- Use command-line arguments to enable or disable features.
Author: OpenAI Assistant
Date: 2023
"""
import os
import sys
import argparse
import logging
import math
import random
import numpy as np
from collections import defaultdict
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# =======================
# Command-Line Arguments
# =======================
def parse_args():
parser = argparse.ArgumentParser(description="Open1 Model Training and Evaluation")
parser.add_argument('--epochs', type=int, default=1, help='Number of training epochs')
parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training')
parser.add_argument('--log_dir', type=str, default='runs', help='Directory for TensorBoard logs')
parser.add_argument('--use_mcts', action='store_true', help='Enable MCTS for trajectory exploration')
parser.add_argument('--multi_agent', action='store_true', help='Enable multi-agent training')
parser.add_argument('--progressive_training', action='store_true', help='Enable progressive training pipeline')
parser.add_argument('--kl_coef', type=float, default=0.2, help='KL Coefficient for PPO')
parser.add_argument('--n_simulations', type=int, default=5, help='Number of MCTS simulations')
parser.add_argument('--visualize', action='store_true', help='Enable visualization of LLM capabilities')
parser.add_argument('--reward_model_epochs', type=int, default=1, help='Number of epochs to train the reward model')
args = parser.parse_args()
return args
# =======================
# Logging Configuration
# =======================
def setup_logging():
logging.basicConfig(
level=logging.INFO,
format='[%(asctime)s] %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout)
]
)
# =======================
# Data Classes and Functions
# =======================
class CustomDataset(Dataset):
def __init__(self, tokenizer, data, max_length):
self.tokenizer = tokenizer
self.data = data
self.max_length = max_length
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
prompt = self.data[idx]['prompt']
response = self.data[idx]['response']
# Concatenate prompt and response for training
input_text = prompt + response
input_ids = self.tokenizer.encode(input_text, truncation=True, max_length=self.max_length)
return torch.tensor(input_ids, dtype=torch.long)
def collate_fn(batch, tokenizer):
input_ids = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=tokenizer.pad_token_id)
attention_mask = (input_ids != tokenizer.pad_token_id).long()
return {'input_ids': input_ids, 'attention_mask': attention_mask}
# =======================
# Model Classes
# =======================
# Transformer-based Language Model with Chain-of-Thought
class Open1Model(nn.Module):
def __init__(self, model_name='gpt2'):
super(Open1Model, self).__init__()
self.model = GPT2LMHeadModel.from_pretrained(model_name)
self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Ensure a unique pad token
if self.tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
self.model.resize_token_embeddings(len(self.tokenizer))
self.model.to(device)
def forward(self, input_ids, attention_mask=None):
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
return outputs
def generate_chain_of_thought(self, prompt, max_length=50, num_return_sequences=1):
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(device)
# Generate with chain-of-thought reasoning
outputs = self.model.generate(
input_ids,
max_length=max_length,
num_return_sequences=num_return_sequences,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
)
generated_texts = [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
return generated_texts
# Reward Model for RLHF
class RewardModel(nn.Module):
def __init__(self, model_name='gpt2', tokenizer=None):
super(RewardModel, self).__init__()
config = GPT2Config.from_pretrained(model_name)
self.model = GPT2LMHeadModel.from_pretrained(model_name)
self.tokenizer = tokenizer
if self.tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
self.model.resize_token_embeddings(len(self.tokenizer))
self.model.to(device)
self.reward_head = nn.Linear(config.n_embd, 1)
self.reward_head.to(device)
def forward(self, input_ids, attention_mask=None):
outputs = self.model.transformer(input_ids=input_ids, attention_mask=attention_mask)
hidden_states = outputs.last_hidden_state # [batch_size, seq_length, hidden_size]
# Take the hidden state of the last token
last_hidden = hidden_states[:, -1, :] # [batch_size, hidden_size]
reward = self.reward_head(last_hidden) # [batch_size, 1]
return reward.squeeze(-1) # [batch_size]
# Proximal Policy Optimization (PPO) Trainer
class PPOTrainer:
def __init__(self, policy_model, reward_model, tokenizer, kl_coef=0.2):
self.policy_model = policy_model
self.reward_model = reward_model
self.tokenizer = tokenizer
self.optimizer = AdamW(self.policy_model.model.parameters(), lr=1e-5)
self.kl_coef = kl_coef
def compute_advantages(self, rewards, values):
# Compute advantages (simplified)
advantages = rewards - values
return advantages
def ppo_step(self, batch):
# Unpack batch
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
old_log_probs = batch['old_log_probs'].to(device)
rewards = batch['rewards'].to(device)
# Get current policy outputs
outputs = self.policy_model.model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits # [batch_size, seq_length, vocab_size]
# Compute log probabilities
log_probs = nn.functional.log_softmax(logits, dim=-1) # [batch_size, seq_length, vocab_size]
# Gather log probs for the selected tokens
selected_log_probs = log_probs.gather(2, input_ids.unsqueeze(-1)).squeeze(-1) # [batch_size, seq_length]
# Compute policy ratio
ratios = torch.exp(selected_log_probs - old_log_probs) # [batch_size, seq_length]
# Compute advantages
values = torch.zeros_like(rewards).to(device) # Placeholder for value estimates
advantages = self.compute_advantages(rewards, values) # [batch_size]
advantages = advantages.unsqueeze(1) # [batch_size, 1]
# Compute PPO loss
surr1 = ratios * advantages # [batch_size, seq_length]
surr2 = torch.clamp(ratios, 0.8, 1.2) * advantages # [batch_size, seq_length]
policy_loss = -torch.min(surr1, surr2).mean()
# Compute KL divergence for regularization
kl_div = torch.nn.functional.kl_div(old_log_probs, selected_log_probs, log_target=True, reduction='batchmean')
# Total loss
loss = policy_loss + self.kl_coef * kl_div
# Backpropagation
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item(), policy_loss.item(), kl_div.item()
# Monte Carlo Tree Search Node
class MCTSNode:
def __init__(self, state, parent=None, prior_prob=1.0):
self.state = state # Sequence of tokens (input_ids)
self.parent = parent
self.children = {}
self.visits = 0
self.value = 0.0
self.prior_prob = prior_prob
def is_leaf(self):
return len(self.children) == 0
# Monte Carlo Tree Search
class MCTS:
def __init__(self, model, tokenizer, reward_model, c_puct=1.0, n_simulations=50, max_depth=20):
self.model = model
self.tokenizer = tokenizer
self.reward_model = reward_model # Add reward_model
self.c_puct = c_puct
self.n_simulations = n_simulations
self.max_depth = max_depth
def search(self, state):
root = MCTSNode(state)
for _ in range(self.n_simulations):
node = root
# Selection
while not node.is_leaf():
node = self.select_child(node)
# Expansion
if len(node.state) < self.max_depth:
self.expand_node(node)
# Simulation
reward = self.simulate(node.state)
# Backpropagation
self.backpropagate(node, reward)
# Choose the best action
if root.children:
best_child = max(root.children.values(), key=lambda n: n.visits)
return best_child.state
else:
# If no children, return the root state
return root.state
def select_child(self, node):
total_visits = sum(child.visits for child in node.children.values())
best_score = -float('inf')
best_child = None
for child in node.children.values():
q_value = child.value / (child.visits + 1e-4)
u_value = self.c_puct * child.prior_prob * math.sqrt(total_visits + 1) / (1 + child.visits)
score = q_value + u_value
if score > best_score:
best_score = score
best_child = child
return best_child
def expand_node(self, node):
input_ids = torch.tensor(node.state).unsqueeze(0).to(device)
outputs = self.model.model(input_ids=input_ids)
logits = outputs.logits[:, -1, :] # [1, vocab_size]
probs = nn.functional.softmax(logits, dim=-1).squeeze() # [vocab_size]
top_k_probs, top_k_indices = torch.topk(probs, k=10)
for idx, prob in zip(top_k_indices, top_k_probs):
token_id = idx.item()
prior_prob = prob.item()
child_state = node.state + [token_id]
node.children[token_id] = MCTSNode(child_state, parent=node, prior_prob=prior_prob)
def simulate(self, state):
# Simulate to the end using the model
input_ids = torch.tensor(state).unsqueeze(0).to(device)
max_length = len(state) + 20
outputs = self.model.model.generate(
input_ids=input_ids,
max_length=max_length,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
)
generated_ids = outputs[0].tolist()
generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
# Calculate reward using the reward model
reward = self.evaluate(generated_text)
return reward
def backpropagate(self, node, reward):
while node is not None:
node.visits += 1
node.value += reward
node = node.parent
def evaluate(self, text):
# Use the reward model to evaluate the text
input_ids = self.tokenizer.encode(text, return_tensors='pt').to(device)
attention_mask = (input_ids != self.tokenizer.pad_token_id).long().to(device)
with torch.no_grad():
reward = self.reward_model(input_ids=input_ids, attention_mask=attention_mask)
return reward.item()
# =======================
# Self-Play Environment
# =======================
class SelfPlayEnvironment:
def __init__(self, agent):
self.agent = agent
def step(self, prompt):
# Agent interacts with itself
response = self.agent.generate_chain_of_thought(prompt, max_length=100)[0]
# The agent then takes its own response as the next prompt
next_prompt = response
return next_prompt, response
def play(self, initial_prompt, max_turns=5):
prompt = initial_prompt
for _ in range(max_turns):
prompt, response = self.step(prompt)
return response
# =======================
# Visualization Functions
# =======================
def plot_capabilities(capabilities, epoch, writer):
import matplotlib.pyplot as plt
labels = list(capabilities.keys())
values = list(capabilities.values())
angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
values += values[:1]
angles += angles[:1]
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
ax.plot(angles, values, 'o-', linewidth=2)
ax.fill(angles, values, alpha=0.25)
ax.set_thetagrids(np.degrees(angles[:-1]), labels)
ax.set_ylim(0, 1)
ax.set_title(f"LLM Capabilities at Step {epoch}")
ax.grid(True)
# Save the plot to TensorBoard
writer.add_figure('Capabilities', fig, epoch)
plt.close(fig)
# =======================
# Training Functions
# =======================
def train_reward_model(reward_model, tokenizer, data_samples, args):
logging.info("Starting training of the Reward Model.")
reward_model.train()
optimizer = AdamW(list(reward_model.model.parameters()) + list(reward_model.reward_head.parameters()), lr=1e-5)
# Create a dataset for the reward model
class RewardDataset(Dataset):
def __init__(self, tokenizer, data, max_length):
self.tokenizer = tokenizer
self.data = data
self.max_length = max_length
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
input_text = self.data[idx]['input_text']
reward = self.data[idx]['reward']
input_ids = self.tokenizer.encode(input_text, truncation=True, max_length=self.max_length)
return torch.tensor(input_ids, dtype=torch.long), torch.tensor(reward, dtype=torch.float)
def reward_collate_fn(batch):
input_ids, rewards = zip(*batch)
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
attention_mask = (input_ids != tokenizer.pad_token_id).long()
rewards = torch.stack(rewards)
return {'input_ids': input_ids, 'attention_mask': attention_mask, 'rewards': rewards}
dataset = RewardDataset(tokenizer, data_samples, max_length=args.max_length)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=reward_collate_fn)
for epoch in range(args.reward_model_epochs):
logging.info(f"Reward Model Epoch {epoch+1}/{args.reward_model_epochs}")
total_loss = 0
for batch in data_loader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
rewards = batch['rewards'].to(device)
# Forward pass
predicted_rewards = reward_model(input_ids=input_ids, attention_mask=attention_mask)
# Loss computation
loss = nn.functional.mse_loss(predicted_rewards, rewards)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(data_loader)
logging.info(f"Average Reward Model Loss: {avg_loss:.4f}")
# Save the trained reward model
torch.save(reward_model.state_dict(), 'reward_model.pth')
logging.info("Reward Model training completed and saved.")
def train_open1(args):
setup_logging()
logging.info("Starting training of Open1 model.")
# Initialize models and tokenizer
policy_model = Open1Model(model_name='gpt2')
tokenizer = policy_model.tokenizer
reward_model = RewardModel(model_name='gpt2', tokenizer=tokenizer)
# Resize embeddings after adding pad_token
if tokenizer.pad_token is not None:
policy_model.model.resize_token_embeddings(len(tokenizer))
reward_model.model.resize_token_embeddings(len(tokenizer))
# Initialize PPO trainer
ppo_trainer = PPOTrainer(policy_model, reward_model, tokenizer, kl_coef=args.kl_coef)
# Initialize TensorBoard writer
writer = SummaryWriter(log_dir=args.log_dir)
# Load real dataset from Hugging Face
dataset_name = 'daily_dialog' # You can choose another dataset
raw_dataset = load_dataset(dataset_name, split='train')
# Prepare data samples
data_samples = []
for entry in raw_dataset:
# We will use the conversation as prompt and response
if 'dialog' in entry:
dialog = entry['dialog']
for i in range(len(dialog) - 1):
prompt = dialog[i]
response = dialog[i + 1]
data_samples.append({'prompt': prompt, 'response': response})
# Limit dataset size for demonstration purposes
data_samples = data_samples[:1000]
# Prepare data for reward model training
reward_data_samples = []
for sample in data_samples:
input_text = sample['prompt'] + sample['response']
reward = len(sample['response'].split()) / 50.0 # Simple heuristic
reward_data_samples.append({'input_text': input_text, 'reward': reward})
# Train the reward model first
train_reward_model(reward_model, tokenizer, reward_data_samples, args)
reward_model.eval() # Set reward model to evaluation mode
# Create dataset and data loader for the policy model
dataset = CustomDataset(tokenizer, data_samples, max_length=args.max_length)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, tokenizer))
# Training loop
for epoch in range(args.epochs):
logging.info(f"Epoch {epoch+1}/{args.epochs}")
for batch_idx, batch in enumerate(data_loader):
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
# Generate initial response with chain-of-thought
generated_texts = []
for input_id in input_ids:
prompt = tokenizer.decode(input_id, skip_special_tokens=True)
if args.use_mcts:
# Use MCTS for trajectory exploration
mcts = MCTS(policy_model, tokenizer, reward_model, n_simulations=args.n_simulations, max_depth=args.max_length)
initial_state = input_id.tolist()
best_state = mcts.search(initial_state)
generated_text = tokenizer.decode(best_state, skip_special_tokens=True)
else:
# Regular generation
generated_text = policy_model.generate_chain_of_thought(prompt, max_length=args.max_length)[0]
generated_texts.append(generated_text)
# Prepare inputs for reward model
generated_ids = [tokenizer.encode(text, return_tensors='pt').squeeze(0) for text in generated_texts]
input_ids_gen = torch.nn.utils.rnn.pad_sequence(generated_ids, batch_first=True, padding_value=tokenizer.pad_token_id).to(device)
attention_mask_gen = (input_ids_gen != tokenizer.pad_token_id).long().to(device)
# Get old log probabilities
with torch.no_grad():
outputs = policy_model.model(input_ids=input_ids_gen, attention_mask=attention_mask_gen)
logits = outputs.logits
log_probs = nn.functional.log_softmax(logits, dim=-1)
old_log_probs = log_probs.gather(2, input_ids_gen.unsqueeze(-1)).squeeze(-1)
# Get rewards from reward model
with torch.no_grad():
rewards = reward_model(input_ids=input_ids_gen, attention_mask=attention_mask_gen)
# Prepare batch
batch_data = {
'input_ids': input_ids_gen,
'attention_mask': attention_mask_gen,
'old_log_probs': old_log_probs,
'rewards': rewards,
}
# PPO optimization step
loss, policy_loss, kl_div = ppo_trainer.ppo_step(batch_data)
# Logging
logging.info(f"Batch {batch_idx+1}/{len(data_loader)}, Loss: {loss:.4f}, Policy Loss: {policy_loss:.4f}, KL Divergence: {kl_div:.4f}")
writer.add_scalar('Loss/Total', loss, epoch * len(data_loader) + batch_idx)
writer.add_scalar('Loss/Policy', policy_loss, epoch * len(data_loader) + batch_idx)
writer.add_scalar('Loss/KL_Divergence', kl_div, epoch * len(data_loader) + batch_idx)
# Visualize LLM capabilities
if args.visualize:
capabilities = {
'Reasoning': random.uniform(0, 1),
'Creativity': random.uniform(0, 1),
'Memory': random.uniform(0, 1),
'Planning': random.uniform(0, 1),
'Problem Solving': random.uniform(0, 1),
'Social Intelligence': random.uniform(0, 1),
}
plot_capabilities(capabilities, epoch * len(data_loader) + batch_idx, writer)
# (Optional) Self-Play Environment
if args.multi_agent:
env = SelfPlayEnvironment(policy_model)
initial_prompt = "Hello, how are you?"
final_response = env.play(initial_prompt)
# Use the final response to further train the model if desired
# For example, append to training data or use as additional training examples
# Include Function Calling Practice
if args.progressive_training:
function_call_prompt = "def add(a, b): return a + b\nadd("
generated_code = policy_model.generate_chain_of_thought(function_call_prompt, max_length=args.max_length)[0]
# Optionally, evaluate and train the model on code generation tasks
# For example, compare generated_code to expected code and adjust rewards accordingly
# Save the trained policy model after each epoch
policy_model.model.save_pretrained('open1_policy_model')
logging.info(f"Epoch {epoch+1} completed and model saved.")
writer.close()
logging.info("Training completed and model saved.")
# =======================
# Main Function
# =======================
if __name__ == "__main__":
args = parse_args()
setup_logging()
train_open1(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment