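"""Zero-shot and few-shot QA evaluation for causal language models.

Evaluates closed-book and open-book (retrieval-augmented) question answering
with a HuggingFace causal LM, optionally wrapped with a LoRA adapter via PEFT.
Each command reads a JSONL dataset and reports substring-match accuracy.
Exposed as a CLI through python-fire; see usage examples at the bottom.
"""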
import datetime
import json
import random
import re
import string
import time
import unicodedata

import fire
import torch
from loguru import logger
from peft import PeftModelForCausalLM
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

log = logger.info

def remove_accents(input_str):
    # strip combining diacritics, e.g. "café" -> "cafe"
    nfkd_form = unicodedata.normalize("NFKD", input_str)
    return "".join(c for c in nfkd_form if not unicodedata.combining(c))


def normalize_answer(text: str) -> str:
    """SQuAD-style normalization: de-accent, lowercase, drop punctuation and articles."""
    text = remove_accents(text)
    text = text.lower()
    # remove punctuation; join with "" so words are not split into single characters
    text = "".join(c for c in text if c not in frozenset(string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    text = " ".join(text.split())
    return text

def generate(tokenizer, prompt, model, max_new_tokens=10, temperature=0.8, top_p=0.95):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # do_sample=True is required for temperature/top_p to take effect;
    # the default is greedy decoding, which silently ignores both
    outputs = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens,
                             do_sample=True, temperature=temperature, top_p=top_p)
    # decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) is fragile because detokenization may not round-trip exactly
    return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True).strip()

def setup_model(model_path, tokenizer_path, lora_path=None):
    log("loading model...")
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16)
    log("loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    added_tokens = tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<pad>"})
    if added_tokens > 0:
        # grow the embedding matrix to cover the newly added special tokens
        model.resize_token_embeddings(len(tokenizer))
    if lora_path is not None:
        log("loading LoRA model...")
        model = PeftModelForCausalLM.from_pretrained(model, lora_path, device_map="auto", torch_dtype=torch.float16)
        model.to(dtype=torch.float16)
    log(f"Mem needed: {model.get_memory_footprint() / 1024 / 1024 / 1024:.2f} GB")
    return model, tokenizer

def extract_answer(text):
    # TODO: still needs to properly extract answers from free-form generations
    is_list_item = False
    if text.startswith("1."):
        is_list_item = True
        text = text.replace("1. ", "")
    # cut the answer at the first newline, period, or comma
    end_idx = len(text)
    for char in ["\n", ".", ","]:
        idx = text.find(char)
        if idx != -1 and idx < end_idx:
            end_idx = idx
    answer = text[:end_idx]
    if is_list_item and answer.endswith("2"):
        # drop the stray leading digit of the next list item, e.g. "Paris 2" -> "Paris"
        answer = answer[:-1].strip()
    return answer

def zero_shot_close_qa(dataset_file, model_path, tokenizer_path, lora_path=None, max_new_tokens=30, temperature=0.8, top_p=0.95):
    """Zero-shot closed-book QA: the model answers from its parameters alone."""
    model, tokenizer = setup_model(model_path, tokenizer_path, lora_path)
    log(f"loading data from {dataset_file}...")
    qa_data = [json.loads(x) for x in open(dataset_file)]
    start_time = time.time()
    correct_count = 0
    p_bar = tqdm(qa_data)
    for qa_item in p_bar:
        question = qa_item["question"]
        answers = qa_item["answers"]
        prompt = f"Answer these questions: \nQ: {question}\nA: "
        pred_text = generate(tokenizer, prompt, model, max_new_tokens, temperature, top_p)
        # stricter alternative: extract a short span and require an exact match
        # pred_ans = extract_answer(pred_text)
        # is_correct = normalize_answer(pred_ans) in frozenset(normalize_answer(ans) for ans in answers)
        pred_ans = normalize_answer(pred_text)
        # lenient metric: correct if any normalized gold answer appears in the prediction
        is_correct = any(normalize_answer(ans) in pred_ans for ans in answers)
        correct_count += int(is_correct)
        p_bar.set_description(f"q={question}, {answers=}, correct={is_correct}, num_correct={correct_count}, {pred_text=}".replace("\n", ""))
    duration = time.time() - start_time
    duration_str = datetime.timedelta(seconds=duration)
    acc = correct_count / len(qa_data) * 100
    log(f"processed {len(qa_data)} examples, all done in {duration_str}, {acc=:.2f}!")

def zero_shot_open_qa(dataset_file, model_path, tokenizer_path, lora_path=None, top_k=5, max_new_tokens=30, temperature=0.8, top_p=0.95):
    """Zero-shot open-book QA: the top-k retrieved passages are prepended to the question."""
    model, tokenizer = setup_model(model_path, tokenizer_path, lora_path)
    log(f"loading data from {dataset_file}...")
    qa_data = [json.loads(x) for x in open(dataset_file)]
    start_time = time.time()
    correct_count = 0
    p_bar = tqdm(qa_data)
    for qa_item in p_bar:
        question = qa_item["question"]
        answers = qa_item["answers"]
        contexts = qa_item["ctxs"][:top_k]
        passages = [c["text"] for c in contexts]
        psg_text = "\n".join(passages)
        prompt = f"Given the following passages: \n{psg_text}\nAnswer the question: {question}\nThe answer is "
        pred_text = generate(tokenizer, prompt, model, max_new_tokens, temperature, top_p)
        pred_ans = normalize_answer(pred_text)
        is_correct = any(normalize_answer(ans) in pred_ans for ans in answers)
        correct_count += int(is_correct)
        p_bar.set_description(f"q={question}, {answers=}, correct={is_correct}, num_correct={correct_count}, {pred_text=}".replace("\n", ""))
    duration = time.time() - start_time
    duration_str = datetime.timedelta(seconds=duration)
    acc = correct_count / len(qa_data) * 100
    log(f"processed {len(qa_data)} examples, all done in {duration_str}, {acc=:.2f}!")

def few_shot_close_qa(dataset_file, train_file, model_path, tokenizer_path, lora_path=None, shot=5, seed=0, max_new_tokens=30, temperature=0.8, top_p=0.95):
    """Few-shot closed-book QA: `shot` Q/A pairs from the train file serve as demonstrations."""
    model, tokenizer = setup_model(model_path, tokenizer_path, lora_path)
    log(f"loading data from {dataset_file}...")
    qa_data = [json.loads(x) for x in open(dataset_file)]
    # sample `shot` demonstration examples from the train data
    train_data = [json.loads(x) for x in open(train_file)]
    random.seed(seed)
    sample_train = random.sample(train_data, shot)
    sample_text = "\n".join(f'Q: {x["question"]}\nA: {x["answers"][0]}' for x in sample_train)
    log(f"{shot}-shot examples: {sample_text}")
    start_time = time.time()
    correct_count = 0
    p_bar = tqdm(qa_data)
    for qa_item in p_bar:
        question = qa_item["question"]
        answers = qa_item["answers"]
        prompt = f"Answer these questions: \n{sample_text}\nQ: {question}\nA: "
        pred_text = generate(tokenizer, prompt, model, max_new_tokens, temperature, top_p)
        pred_ans = normalize_answer(pred_text)
        is_correct = any(normalize_answer(ans) in pred_ans for ans in answers)
        correct_count += int(is_correct)
        p_bar.set_description(f"q={question}, {answers=}, correct={is_correct}, num_correct={correct_count}, {pred_text=}".replace("\n", ""))
    duration = time.time() - start_time
    duration_str = datetime.timedelta(seconds=duration)
    acc = correct_count / len(qa_data) * 100
    log(f"processed {len(qa_data)} examples, all done in {duration_str}, {acc=:.2f}!")

def few_shot_open_qa(dataset_file, train_file, model_path, tokenizer_path, lora_path=None, top_k=5, shot=5, seed=0, max_new_tokens=30, temperature=0.8, top_p=0.95):
    """Few-shot open-book QA: each demonstration includes its top-k passages, as does the test question."""
    model, tokenizer = setup_model(model_path, tokenizer_path, lora_path)
    log(f"loading data from {dataset_file}...")
    qa_data = [json.loads(x) for x in open(dataset_file)]
    # sample `shot` demonstration examples from the train data
    train_data = [json.loads(x) for x in open(train_file)]
    random.seed(seed)
    sample_train = random.sample(train_data, shot)
    sample_texts = []
    for item in sample_train:
        psg_text = "\n".join(c["text"] for c in item["ctxs"][:top_k])
        sample_t = f'{psg_text}\nQ: {item["question"]}\nA: {item["answers"][0]}'
        sample_texts.append(sample_t)
    sample_text = "\n".join(sample_texts)
    log(f"{shot}-shot examples: {sample_text}")
    start_time = time.time()
    correct_count = 0
    p_bar = tqdm(qa_data)
    for qa_item in p_bar:
        question = qa_item["question"]
        answers = qa_item["answers"]
prompt = f"\n{sample_text}\nQ: {question}\nA: " | |
        pred_text = generate(tokenizer, prompt, model, max_new_tokens, temperature, top_p)
        pred_ans = normalize_answer(pred_text)
        is_correct = any(normalize_answer(ans) in pred_ans for ans in answers)
        correct_count += int(is_correct)
        p_bar.set_description(f"q={question}, {answers=}, correct={is_correct}, num_correct={correct_count}, {pred_text=}".replace("\n", ""))
    duration = time.time() - start_time
    duration_str = datetime.timedelta(seconds=duration)
    acc = correct_count / len(qa_data) * 100
    log(f"processed {len(qa_data)} examples, all done in {duration_str}, {acc=:.2f}!")

if __name__ == "__main__":
    fire.Fire()
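
# Example invocations via the fire CLI. The script name, file paths, and model
# locations below are illustrative placeholders, not values from this gist:
#   python eval_qa.py zero_shot_close_qa --dataset_file data/dev.jsonl \
#       --model_path /path/to/model --tokenizer_path /path/to/tokenizer
#   python eval_qa.py few_shot_open_qa --dataset_file data/dev.jsonl \
#       --train_file data/train.jsonl --model_path /path/to/model \
#       --tokenizer_path /path/to/tokenizer --lora_path /path/to/lora --shot 5 --top_k 5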