Online AI Feedback T5
from datasets import Dataset, load_from_disk
from transformers import TrainingArguments
from transformers.trainer_utils import EvalLoopOutput
from unsloth import FastLanguageModel
import random
from huggingface_hub import create_repo
from scipy.spatial.distance import cosine
from sentence_transformers import SentenceTransformer
import statistics
from typing import Dict, Union, Any
import torch
from torch.utils.data import DataLoader
import trl
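

# Online AI Feedback (OAIF) DPO training: instead of a static preference
# dataset, each training batch samples several completions from the current
# policy, scores them with an annotator, and trains on the best/worst pair as
# (chosen, rejected). Generation and optimization parameters follow the paper
# referenced below, except where noted.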
class DynamicDataLoader:
    def __init__(self, base_dataloader, mutate_fn):
        self.base_dataloader = base_dataloader
        self.mutate_fn = mutate_fn

    def __iter__(self):
        for batch in self.base_dataloader:
            yield self.mutate_fn(batch)

    def __len__(self):
        return len(self.base_dataloader)
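

# Context manager which switches an unsloth model into inference mode for
# generation and restores training mode on exit, but only if the model was
# actually training when we entered.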
class eval_mode:
    def __init__(self, model):
        self.model = model

    def __enter__(self):
        self.was_training = self.model.training
        if self.was_training:
            FastLanguageModel.for_inference(self.model)
            self.model.eval()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.was_training:
            FastLanguageModel.for_training(self.model)
            self.model.train()
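

# DPOTrainer subclass implementing online feedback: the datasets get
# placeholder "chosen"/"rejected" columns so DPOTrainer's preprocessing
# succeeds, and real preference pairs are generated per batch by
# oaif_label_rewards().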
class OAIFTrainer(trl.DPOTrainer):
    def __init__(
        self,
        *args,
        oaif_annotator,
        train_dataset,
        eval_oaif_annotator=None,
        eval_dataset=None,
        **kwargs
    ):
        # insert mock chosen / rejected
        train_dataset = self._pre_patch_dataset(train_dataset)
        if eval_dataset is not None:
            eval_dataset = self._pre_patch_dataset(eval_dataset)
        super().__init__(*args, train_dataset=train_dataset, eval_dataset=eval_dataset, **kwargs)
        self.oaif_annotator = oaif_annotator
        self.eval_oaif_annotator = eval_oaif_annotator if eval_oaif_annotator else oaif_annotator

    @staticmethod
    def _pre_patch_dataset(ds):
        assert "chosen" not in ds.column_names
        assert "rejected" not in ds.column_names
        def add_columns(example):
            example['chosen'] = ''
            example['rejected'] = ''
            return example
        return ds.map(add_columns)

    @staticmethod
    def _post_patch_dataset(ds):
        # note: currently unused
        nullified_keys = [
            "chosen", "chosen_input_ids", "chosen_attention_mask", "chosen_labels",
            "rejected", "rejected_input_ids", "rejected_attention_mask", "rejected_labels",
        ]
        for i in range(len(ds)):
            for nullified_key in nullified_keys:
                ds[nullified_key][i] = None

    @staticmethod
    def _batch_list(lst, batch_size):
        return [
            lst[i:i + batch_size]
            for i in range(0, len(lst), batch_size)
        ]

    @staticmethod
    def rstrip_pad_token(t, pad_token_id):
        non_pad_indices = (t != pad_token_id).nonzero(as_tuple=True)[0]
        if len(non_pad_indices) == 0:
            return t
        last_non_pad_index = non_pad_indices[-1].item()
        return t[:last_non_pad_index + 1]

    def oaif_label_rewards(self, batch):
        """
        Use the model to sample candidate responses, then use
        self.oaif_annotator to label the chosen and rejected completions.
        """
        # parameters from paper
        # https://www.semanticscholar.org/reader/04d64be16fb402f28348faffef484bd419c8bd8f
        #temperature = 0.7
        num_return_sequences = 4
        top_p = 0.9
        # deviates from paper
        temperature = 0.9
        batch_size = 4

        response_groups = []
        with eval_mode(self.model):
            for queries in self._batch_list(batch["prompt"], batch_size):
                masked_inputs = self.tokenizer(queries, padding=True, return_tensors="pt").to("cuda")
                generation = self.model.generate(
                    **masked_inputs,
                    num_return_sequences=num_return_sequences,
                    pad_token_id=self.tokenizer.pad_token_id,
                    # TODO: implement use of self.oaif_generation_kwargs for below args
                    max_new_tokens=128,
                    do_sample=True,
                    top_p=top_p,
                    temperature=temperature,
                )
                # strip the prompt tokens, then regroup the flat
                # [batch * num_return_sequences, seq_len] output by query
                response_group_chunk = generation[:, masked_inputs.input_ids.shape[1]:]
                response_group_chunk = response_group_chunk.view(
                    -1,
                    num_return_sequences,
                    response_group_chunk.shape[1]
                )
                for responses in list(response_group_chunk):
                    response_groups.append(tuple([
                        self.rstrip_pad_token(resp, self.tokenizer.pad_token_id)
                        for resp in responses
                    ]))

        # generate annotations
        base_annotations = self.oaif_annotator(
            batch,
            response_groups,
            self.tokenizer,
        )
        annotated_ds = Dataset.from_dict(base_annotations).map(self.tokenize_row)
        collated_ds = self.data_collator(annotated_ds)

        # hack: expose the prompt tensors under the generic key names
        collated_ds["attention_mask"] = collated_ds["prompt_attention_mask"]
        collated_ds["input_ids"] = collated_ds["prompt_input_ids"]
        return collated_ds
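
    # get_train_dataloader is the online-feedback hook: every batch from the
    # base DPO dataloader is re-labeled with freshly sampled chosen/rejected
    # pairs before it reaches the training step.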
    def get_train_dataloader(self):
        dataloader = super().get_train_dataloader()
        mutate_fn = lambda batch: {**batch, **self.oaif_label_rewards(batch)}
        return DynamicDataLoader(dataloader, mutate_fn)

    def evaluation_loop(self, dataloader, *args, metric_key_prefix="eval", **kwargs):
        """
        Modified evaluation loop which scores greedy generations against the
        reference prompts. Hacky: this logic belongs in
        SharpenedCosineSimilarityAnnotator, not in this class.
        """
        greedy_responses = []
        true_responses = []
        with eval_mode(self.model):
            for batch in dataloader:
                queries = batch["prompt"]
                true_responses += batch["resolved_prompt"]
                masked_inputs = self.tokenizer(queries, padding=True, return_tensors="pt").to("cuda")
                generation = self.model.generate(
                    **masked_inputs,
                    pad_token_id=self.tokenizer.pad_token_id,
                    max_new_tokens=128,
                    do_sample=False,
                )
                greedy_responses += list(generation[:, masked_inputs.input_ids.shape[1]:])

        observed_responses = [
            self.tokenizer.decode(resp, skip_special_tokens=True).strip()
            for resp in greedy_responses
        ]
        assert len(observed_responses) == len(true_responses)
        rewards = list(map(float,
            self.eval_oaif_annotator.get_reward(observed_responses, true_responses)
        ))

        exact_prefix = f"{metric_key_prefix}_oaif_exact"
        q1, q2, q3 = statistics.quantiles(rewards, n=4)
        oaif_metrics = {
            f"{exact_prefix}_mean": statistics.mean(rewards),
            f"{exact_prefix}_std_dev": statistics.stdev(rewards),
            f"{exact_prefix}_min": min(rewards),
            f"{exact_prefix}_q1": q1,
            f"{exact_prefix}_q2": q2,
            f"{exact_prefix}_q3": q3,
            f"{exact_prefix}_max": max(rewards),
        }

        print()
        for i in range(min(4, len(rewards))):
            print(f"rewards[{i}]:", rewards[i])
            print(f"\tobserved_prompt[{i}]:", observed_responses[i])
            print(f"\ttrue_prompt[{i}]:", true_responses[i])
        max_reward_idx = rewards.index(max(rewards))
        print("rewards[max]:", rewards[max_reward_idx])
        print("\tobserved_prompt[max]:", observed_responses[max_reward_idx])
        print("\ttrue_prompt[max]:", true_responses[max_reward_idx])
        print()

        return EvalLoopOutput(
            predictions=None,
            label_ids=None,
            metrics=oaif_metrics,
            num_samples=len(true_responses),
        )
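

# Scores a generated prompt against its reference with a "sharpened" cosine
# similarity over sentence-t5 embeddings: reward = (1 - cosine_distance)**3.
# Cubing pushes middling similarities toward zero while leaving near-exact
# matches close to 1.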
class SharpenedCosineSimilarityAnnotator:
    def __init__(self, reference_prompts, embed_model_id='sentence-transformers/sentence-t5-base'):
        self.embed_model = SentenceTransformer(embed_model_id)
        self.reference_prompts = reference_prompts
        self.embedding_cache = {}

    def get_embedding(self, s):
        if s not in self.embedding_cache:
            self.embedding_cache[s] = self.embed_model.encode([s])[0]
        return self.embedding_cache[s]

    def get_sharpened_cos_sim(self, observed_prompts, true_prompts):
        observed_embeddings = [self.get_embedding(p) for p in observed_prompts]
        true_embeddings = [self.get_embedding(p) for p in true_prompts]
        similarities = [
            (1 - cosine(o, r))**3
            for o, r in zip(observed_embeddings, true_embeddings)
        ]
        return torch.tensor(similarities)

    def get_closest_n_similarity(self, observed_prompt, n=None):
        """
        Get the sharpened similarity of the n'th most similar reference prompt.
        """
        if n is None:
            n = int(len(self.reference_prompts) / 4)
        observed_embedding = self.get_embedding(observed_prompt)
        reference_embeddings = [self.get_embedding(tp) for tp in self.reference_prompts]
        similarities = torch.tensor([
            (1 - cosine(observed_embedding, ref_emb))**3
            for ref_emb in reference_embeddings
        ])
        top_n_sim, _ = torch.topk(similarities, n)
        return torch.min(top_n_sim)

    def get_reward(self, observed_prompts, true_prompts):
        exact_rewards = self.get_sharpened_cos_sim(
            list(map(str.strip, observed_prompts)),
            true_prompts
        )
        # penalize multi-line responses with 3/4 of the reward
        return [
            r * 3 / 4 if "\n" in op.strip() else r
            for op, r in zip(observed_prompts, exact_rewards)
        ]

    def __call__(self, batch, response_groups, tokenizer):
        chosen = []
        rejected = []
        for resp_ids_group, resolved_prompt in zip(response_groups, batch["resolved_prompt"]):
            best = None
            best_reward = None
            worst = None
            worst_reward = None
            for resp_ids in resp_ids_group:
                response_string = tokenizer.decode(resp_ids, skip_special_tokens=True)
                reward = self.get_reward([response_string], [resolved_prompt])[0]
                if best_reward is None or reward > best_reward:
                    best_reward = reward
                    best = tokenizer.decode(resp_ids, skip_special_tokens=False)
                if worst_reward is None or reward < worst_reward:
                    worst_reward = reward
                    worst = tokenizer.decode(resp_ids, skip_special_tokens=False)
            chosen.append(best)
            rejected.append(worst)
        return {
            "prompt": batch["prompt"],
            "chosen": chosen,
            "rejected": rejected,
        }
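

# Load the base model 4-bit quantized and attach LoRA adapters to the attention
# and FFN projections (a QLoRA-style setup via unsloth). Left padding is
# required so that batched generation appends new tokens directly after each
# prompt.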
def get_unsloth_model(base_model_name, max_seq_length=2048):
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=max_seq_length,
        dtype=None,
        load_in_4bit=True,
    )
    peft_model = FastLanguageModel.get_peft_model(
        model,
        target_modules=[
            "q_proj", "v_proj", "k_proj", "o_proj",  # attention (self_attn)
            "gate_proj", "down_proj", "up_proj",     # FFN (mlp)
        ],
        r=8,
        lora_alpha=32,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing=True,
        max_seq_length=max_seq_length,
    )
    tokenizer.padding_side = "left"
    return peft_model, tokenizer
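

# Build the OAIFTrainer. DPO hyperparameters (batch size, learning rate, beta,
# warmup) follow the paper linked in the signature; checkpoints are evaluated
# and pushed to the Hub every 12 training steps.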
def get_trainer(
    model,
    tokenizer,
    train_dataset,
    eval_dataset,
    output_dir,
    hub_repo_id,
    # parameters from paper
    # https://www.semanticscholar.org/reader/04d64be16fb402f28348faffef484bd419c8bd8f
    train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    beta=0.1,
    eval_batch_size=8,
    loss_type="sigmoid",  # trl's name for the standard DPO loss
    warmup_steps=0,
    max_prompt_length=2048,
    max_length=2048 + 128,
    seed=42
):
    training_args = TrainingArguments(
        learning_rate=learning_rate,
        lr_scheduler_type="constant",
        warmup_steps=warmup_steps,
        optim="paged_adamw_8bit",
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        report_to="tensorboard",
        logging_steps=1,
        evaluation_strategy="steps",
        eval_steps=12,
        save_strategy="steps",
        save_steps=12,
        push_to_hub=True,
        hub_private_repo=True,
        hub_model_id=hub_repo_id,
        hub_strategy="every_save",
        #load_best_model_at_end=True,
        gradient_checkpointing=True,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_checkpointing_kwargs=dict(use_reentrant=True),
        output_dir=output_dir,
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported(),
        seed=seed,
        # experiments suggest learning past 1.5 epochs is mostly useless
        num_train_epochs=3,
        # hack
        metric_for_best_model=None,
        # stabilize
        max_grad_norm=10.0,
        # label
        run_name="oaif_standard",
    )
    return OAIFTrainer(
        model,
        ref_model=None,  # TODO: would be nice to eval against base model by stripping the adapters
        args=training_args,
        beta=beta,
        loss_type=loss_type,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_prompt_length=max_prompt_length,
        max_length=max_length,
        oaif_annotator=SharpenedCosineSimilarityAnnotator(train_dataset["resolved_prompt"]),
        eval_oaif_annotator=SharpenedCosineSimilarityAnnotator(eval_dataset["resolved_prompt"]),
    )
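

# ChatML-formatted prompt for the reverse-prompt task: given an original text
# and its rewritten form, the model must recover the transformation prompt
# that produced the rewrite.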
prompt_template = """<|im_start|>user
You are a reverse prompt generator. The user will provide two texts, "Original Text" and "Rewritten Text". You will determine the "Transformation Prompt" which was applied to "Original Text" with a language model to generate "Rewritten Text". Analyze the changes in style, theme, etc. to determine which "Transformation Prompt" was used with a language model to convert "Original Text" into "Rewritten Text".<|im_end|>
<|im_start|>user
Original Text:
'''
{}
'''
Rewritten Text:
'''
{}
'''
What is the Transformation Prompt which was applied to modify the input text into the transformed output text?<|im_end|>
<|im_start|>assistant
Transformation Prompt:
"""
def format_query(example):
    return prompt_template.format(
        example['original_text'].strip(),
        example['rewritten_text'].strip(),
    )
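

# Load the train/eval splits from a local HF dataset; each example keeps its
# ground-truth transformation prompt as "resolved_prompt" and gets a formatted
# model input as "prompt".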
def get_datasets():
    def coll(example):
        return {
            "resolved_prompt": example["prompt"],
            "prompt": format_query(example),
        }

    # Loading the datasets
    ds = load_from_disk("dataset_v43")
    train_dataset = ds["oaif_train"]
    eval_dataset = ds["eval"]

    # Applying transformations; OAIFTrainer inserts the placeholder
    # "chosen" / "rejected" columns itself
    train_dataset = train_dataset.map(coll)
    train_dataset = train_dataset.remove_columns([
        col for col in train_dataset.column_names
        if col not in ['prompt', 'resolved_prompt']
    ])
    train_dataset = train_dataset.shuffle(seed=42)

    eval_dataset = eval_dataset.map(coll)
    eval_dataset = eval_dataset.remove_columns([
        col for col in eval_dataset.column_names
        if col not in ['prompt', 'resolved_prompt']
    ])
    return train_dataset, eval_dataset
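

# Entry point: load the SFT-merged base model, build the datasets and trainer,
# run a step-0 evaluation as a baseline, then train.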
if __name__ == "__main__":
    output_dir = "oaif_v43.1"
    base_model_name = "sft_v43_merged"
    hub_repo_id = f"lapp0/{output_dir}"

    model, tokenizer = get_unsloth_model(base_model_name=base_model_name)
    train_ds, eval_ds = get_datasets()
    trainer = get_trainer(
        model,
        tokenizer,
        train_ds,
        eval_ds,
        output_dir=output_dir,
        hub_repo_id=hub_repo_id
    )
    trainer.evaluate()  # step 0
    trainer.train()